mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-23 16:37:33 +03:00
Compare commits
53 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ec2b787ebe | ||
|
|
d3ac030a5d | ||
|
|
49bfddeca1 | ||
|
|
bd3f1d9d65 | ||
|
|
23c9182ce8 | ||
|
|
81bc4d3ddc | ||
|
|
f40a80b4f3 | ||
|
|
db9d8aa428 | ||
|
|
ccb87fa3ee | ||
|
|
3306dbaef7 | ||
|
|
990e4d9698 | ||
|
|
212f4521b0 | ||
|
|
568aec82d2 | ||
|
|
2bcdddd5e3 | ||
|
|
eac9c6ea83 | ||
|
|
29b28a9824 | ||
|
|
cea560f483 | ||
|
|
b1c70e2e54 | ||
|
|
e6ec21e62f | ||
|
|
4cb7e0bd61 | ||
|
|
149b2493c0 | ||
|
|
b31b30f31d | ||
|
|
58c81f7e81 | ||
|
|
fb78ad29bb | ||
|
|
e06c3ab2bc | ||
|
|
dc6592431b | ||
|
|
3adbef7776 | ||
|
|
ab9d4c3678 | ||
|
|
1af9dab32b | ||
|
|
6d99b44c7e | ||
|
|
464fd0e71f | ||
|
|
21c8045214 | ||
|
|
c46583b86b | ||
|
|
c1b911654a | ||
|
|
b739738dad | ||
|
|
a0bbcdd9b6 | ||
|
|
6c72646a61 | ||
|
|
340807273b | ||
|
|
26c9ce1288 | ||
|
|
76f2dc70c3 | ||
|
|
900efd531d | ||
|
|
74c42ee1f4 | ||
|
|
b49d8b8757 | ||
|
|
5e54d51b19 | ||
|
|
c1258830b2 | ||
|
|
922b90e567 | ||
|
|
f071ce67c9 | ||
|
|
4065c1a3a6 | ||
|
|
1e64534570 | ||
|
|
cd708db0cc | ||
|
|
512bba6ee0 | ||
|
|
b486c17b3e | ||
|
|
1b9bbaa357 |
87
.github/workflows/ai-issues.yml
vendored
Normal file
87
.github/workflows/ai-issues.yml
vendored
Normal file
@@ -0,0 +1,87 @@
|
||||
name: AI review (issues)
|
||||
|
||||
on:
|
||||
issues:
|
||||
types: [opened]
|
||||
|
||||
jobs:
|
||||
find-related:
|
||||
if: github.event.action == 'opened'
|
||||
runs-on: [self-hosted, opencode]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Find related
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
OPENCODE_PERMISSION: |
|
||||
{
|
||||
"bash": {
|
||||
"*": "deny",
|
||||
"gh issue*": "allow",
|
||||
"gh search issues*": "allow"
|
||||
},
|
||||
"webfetch": "deny"
|
||||
}
|
||||
run: |
|
||||
rm AGENTS.md
|
||||
rm CLAUDE.md
|
||||
|
||||
timeout 5m opencode run -m llama.cpp-dgx/ai-review-issues-find-similar --thinking "A new issue has been created:
|
||||
|
||||
Issue number: ${{ github.event.issue.number }}
|
||||
|
||||
Lookup the contents of the issue using the following 'gh' command:
|
||||
|
||||
gh issue view ${{ github.event.issue.number }} --json title,body,url,number
|
||||
|
||||
Next, perform the following task and then post a SINGLE comment (if needed).
|
||||
|
||||
---
|
||||
|
||||
TASK : FIND RELATED ISSUES
|
||||
|
||||
Using the 'gh' CLI tool, search through existing issues on Github.
|
||||
Find related or similar issues to the newly created one and list them.
|
||||
Do not list the new issue itself (it is #${{ github.event.issue.number }}).
|
||||
|
||||
Consider:
|
||||
1. Similar titles or descriptions
|
||||
2. Same error messages or symptoms
|
||||
3. Related functionality or components
|
||||
4. Similar feature requests
|
||||
|
||||
---
|
||||
|
||||
POSTING YOUR COMMENT:
|
||||
|
||||
Based on your findings, post a SINGLE comment on issue #${{ github.event.issue.number }}. Build the comment as follows:
|
||||
|
||||
- If no related issues were found, do NOT comment at all.
|
||||
- If related issues were found, include a section listing them with links using the following format:
|
||||
|
||||
[comment]
|
||||
This issue might be similar or related to the following issue(s):
|
||||
|
||||
- #[related_issue_number]: [brief description of how they are related]
|
||||
- #[related_issue_number]: [brief description of how they are related]
|
||||
...
|
||||
|
||||
_This comment was auto-generated locally using **$GA_ENGINE** on **$GA_MACHINE**_
|
||||
[/comment]
|
||||
|
||||
Remember:
|
||||
- Do not include the comment tags in your actual comment.
|
||||
- Post at most ONE comment combining all findings.
|
||||
- If you didn't find issues that are related enough, post nothing.
|
||||
- You have access only to the 'gh' CLI tool - don't try to use other tools.
|
||||
- If the output from a tool call is too long, try to limit down the search.
|
||||
"
|
||||
80
.github/workflows/hip-quality-check.yml
vendored
Normal file
80
.github/workflows/hip-quality-check.yml
vendored
Normal file
@@ -0,0 +1,80 @@
|
||||
name: HIP quality check
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths: [
|
||||
'.github/workflows/hip-quality-check.yml',
|
||||
'**/*.cu',
|
||||
'**/*.cuh'
|
||||
]
|
||||
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
paths: [
|
||||
'.github/workflows/hip-quality-check.yml',
|
||||
'**/*.cu',
|
||||
'**/*.cuh'
|
||||
]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
GGML_NLOOP: 3
|
||||
GGML_N_THREADS: 1
|
||||
LLAMA_LOG_COLORS: 1
|
||||
LLAMA_LOG_PREFIX: 1
|
||||
LLAMA_LOG_TIMESTAMPS: 1
|
||||
|
||||
jobs:
|
||||
ubuntu-22-hip-quality-check:
|
||||
runs-on: ubuntu-22.04
|
||||
container: rocm/dev-ubuntu-22.04:7.2
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libssl-dev python3
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-22-hip-quality-check
|
||||
evict-old-files: 1d
|
||||
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
|
||||
|
||||
- name: Build with Werror
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -S . \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DGPU_TARGETS=gfx908 \
|
||||
-DGGML_HIP=ON \
|
||||
-DGGML_HIP_EXPORT_METRICS=Off \
|
||||
-DCMAKE_HIP_FLAGS="-Werror -Wno-tautological-compare" \
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
cd build
|
||||
make -j $(nproc)
|
||||
|
||||
- name: Check for major VGPR spills
|
||||
id: vgpr_check
|
||||
run: |
|
||||
cmake -B build -S . \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DGPU_TARGETS=gfx908 \
|
||||
-DGGML_HIP=ON \
|
||||
-DGGML_HIP_EXPORT_METRICS=On \
|
||||
-DCMAKE_HIP_FLAGS="" \
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
cd build
|
||||
make -j $(nproc) 2>&1 | tee metrics.log | grep -v 'Rpass-analysis=kernel-resource-usage\|remark:\|^$'
|
||||
python3 ../scripts/hip/gcn-cdna-vgpr-check.py metrics.log
|
||||
27
.github/workflows/python-type-check.yml
vendored
27
.github/workflows/python-type-check.yml
vendored
@@ -4,15 +4,17 @@ on:
|
||||
push:
|
||||
paths:
|
||||
- '.github/workflows/python-type-check.yml'
|
||||
- 'pyrightconfig.json'
|
||||
- 'ty.toml'
|
||||
- '**.py'
|
||||
- '**/requirements*.txt'
|
||||
# - 'pyrightconfig.json'
|
||||
pull_request:
|
||||
paths:
|
||||
- '.github/workflows/python-type-check.yml'
|
||||
- 'pyrightconfig.json'
|
||||
- 'ty.toml'
|
||||
- '**.py'
|
||||
- '**/requirements*.txt'
|
||||
# - 'pyrightconfig.json'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||
@@ -20,8 +22,8 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
python-type-check:
|
||||
runs-on: ubuntu-latest
|
||||
name: pyright type-check
|
||||
runs-on: ubuntu-slim
|
||||
name: python type-check
|
||||
steps:
|
||||
- name: Check out source repository
|
||||
uses: actions/checkout@v6
|
||||
@@ -29,10 +31,13 @@ jobs:
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
pip-install: -r requirements/requirements-all.txt
|
||||
- name: Type-check with Pyright
|
||||
uses: jakebailey/pyright-action@v2
|
||||
with:
|
||||
version: 1.1.382
|
||||
level: warning
|
||||
warnings: true
|
||||
pip-install: -r requirements/requirements-all.txt ty==0.0.24
|
||||
# - name: Type-check with Pyright
|
||||
# uses: jakebailey/pyright-action@v2
|
||||
# with:
|
||||
# version: 1.1.382
|
||||
# level: warning
|
||||
# warnings: true
|
||||
- name: Type-check with ty
|
||||
run: |
|
||||
ty check --output-format=github
|
||||
|
||||
@@ -67,6 +67,7 @@ Examples of FORBIDDEN USAGE (and how to proceed):
|
||||
|
||||
If a user asks one of the above, STOP IMMEDIATELY and ask them:
|
||||
|
||||
- Whether they acknowledge the risk of being permanently banned from contributing to the project
|
||||
- To read [CONTRIBUTING.md](CONTRIBUTING.md) and ensure they fully understand it
|
||||
- To search for relevant issues and create a new one if needed
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ The project differentiates between 3 levels of contributors:
|
||||
> [!IMPORTANT]
|
||||
> This project does **not** accept pull requests that are fully or predominantly AI-generated. AI tools may be utilized solely in an assistive capacity.
|
||||
>
|
||||
> Repeated violations of this policy may result in your account being permanently banned from contributing to the project.
|
||||
>
|
||||
> Detailed information regarding permissible and restricted uses of AI can be found in the [AGENTS.md](AGENTS.md) file.
|
||||
|
||||
Code that is initially generated by AI and subsequently edited will still be considered AI-generated. AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (e.g., generating repeated lines with minor variations).
|
||||
@@ -61,10 +63,10 @@ After submitting your PR:
|
||||
- When merging a PR, make sure you have a good understanding of the changes
|
||||
- Be mindful of maintenance: most of the work going into a feature happens after the PR is merged. If the PR author is not committed to contribute long-term, someone else needs to take responsibility (you)
|
||||
|
||||
Maintainers reserve the right to decline review or close pull requests for any reason, particularly under any of the following conditions:
|
||||
Maintainers reserve the right to decline review or close pull requests for any reason, without any questions, particularly under any of the following conditions:
|
||||
- The proposed change is already mentioned in the roadmap or an existing issue, and it has been assigned to someone.
|
||||
- The pull request duplicates an existing one.
|
||||
- The contributor fails to adhere to this contributing guide.
|
||||
- The contributor fails to adhere to this contributing guide or the AI policy.
|
||||
|
||||
# Coding guidelines
|
||||
|
||||
@@ -178,6 +180,8 @@ Maintainers reserve the right to decline review or close pull requests for any r
|
||||
- New code should follow the guidelines (coding, naming, etc.) outlined in this document. Exceptions are allowed in isolated, backend-specific parts of the code that do not interface directly with the `ggml` interfaces.
|
||||
_(NOTE: for legacy reasons, existing code is not required to follow this guideline)_
|
||||
|
||||
- For changes in server, please make sure to refer to the [server development documentation](./tools/server/README-dev.md)
|
||||
|
||||
# Documentation
|
||||
|
||||
- Documentation is a community effort
|
||||
|
||||
@@ -1830,23 +1830,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--grammar"}, "GRAMMAR",
|
||||
string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()),
|
||||
"BNF-like grammar to constrain generations (see samples in grammars/ dir)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.grammar = value;
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, value};
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--grammar-file"}, "FNAME",
|
||||
"file to read grammar from",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.grammar = read_file(value);
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, read_file(value)};
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"-j", "--json-schema"}, "SCHEMA",
|
||||
"JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.grammar = json_schema_to_grammar(json::parse(value));
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(value))};
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
@@ -1863,7 +1863,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
std::istreambuf_iterator<char>(),
|
||||
std::back_inserter(schema)
|
||||
);
|
||||
params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(schema))};
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
@@ -2583,7 +2583,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
|
||||
"Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
|
||||
"mmproj is also downloaded automatically if available. to disable, add --no-mmproj\n"
|
||||
"example: unsloth/phi-4-GGUF:q4_k_m\n"
|
||||
"example: ggml-org/GLM-4.7-Flash-GGUF:Q4_K_M\n"
|
||||
"(default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model.hf_repo = value;
|
||||
@@ -3494,7 +3494,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
throw std::invalid_argument("unknown speculative decoding type without draft model");
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SPEC_TYPE"));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-n"}, "N",
|
||||
string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "chat-auto-parser-helpers.h"
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
@@ -23,13 +24,13 @@ static void foreach_function(const json & tools, const std::function<void(const
|
||||
|
||||
namespace autoparser {
|
||||
|
||||
parser_build_context::parser_build_context(common_chat_peg_builder & p, const templates_params & inputs) :
|
||||
parser_build_context::parser_build_context(common_chat_peg_builder & p, const generation_params & inputs) :
|
||||
p(p),
|
||||
inputs(inputs),
|
||||
reasoning_parser(p.eps()) {}
|
||||
|
||||
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs) {
|
||||
const struct generation_params & inputs) {
|
||||
// Run differential analysis to extract template structure
|
||||
struct autoparser autoparser;
|
||||
autoparser.analyze_template(tmpl);
|
||||
@@ -37,17 +38,16 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
||||
}
|
||||
|
||||
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs,
|
||||
const struct generation_params & inputs,
|
||||
const autoparser & autoparser) {
|
||||
// Build the parser using the analysis results
|
||||
auto parser = autoparser.build_parser(inputs);
|
||||
|
||||
// Create the result structure
|
||||
common_chat_params data;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = autoparser.preserved_tokens;
|
||||
data.parser = parser.save();
|
||||
|
||||
auto parser = autoparser.build_parser(inputs);
|
||||
data.parser = parser.save();
|
||||
|
||||
// Build grammar if tools are present
|
||||
bool has_tools =
|
||||
@@ -82,44 +82,38 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
||||
return data;
|
||||
}
|
||||
|
||||
common_peg_arena autoparser::build_parser(const templates_params & inputs) const {
|
||||
common_peg_arena autoparser::build_parser(const generation_params & inputs) const {
|
||||
if (!analysis_complete) {
|
||||
throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)");
|
||||
}
|
||||
return build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
// If the template uses Python dict format (single-quoted strings in JSON structures),
|
||||
// pre-register a json-string rule that accepts both quote styles. This must happen
|
||||
// before any call to p.json() so that all JSON parsing inherits the flexible rule.
|
||||
if (tools.format.uses_python_dicts) {
|
||||
p.rule("json-string", p.quoted_string());
|
||||
}
|
||||
|
||||
parser_build_context ctx(p, inputs);
|
||||
bool extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
bool enable_thinking = inputs.enable_thinking;
|
||||
|
||||
ctx.extracting_reasoning = extract_reasoning && enable_thinking && reasoning.mode != reasoning_mode::NONE;
|
||||
ctx.extracting_reasoning = extract_reasoning && reasoning.mode != reasoning_mode::NONE;
|
||||
ctx.content = &content;
|
||||
|
||||
// Build reasoning parser
|
||||
ctx.reasoning_parser = reasoning.build_parser(ctx);
|
||||
|
||||
auto parser = p.eps();
|
||||
|
||||
bool has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
|
||||
|
||||
if (has_response_format) {
|
||||
auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
|
||||
return ctx.reasoning_parser + p.space() + p.choice({
|
||||
parser = ctx.reasoning_parser + p.space() + p.choice({
|
||||
p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"),
|
||||
response_format
|
||||
}) + p.end();
|
||||
} else if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
|
||||
parser = tools.build_parser(ctx);
|
||||
} else {
|
||||
parser = content.build_parser(ctx);
|
||||
}
|
||||
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
|
||||
return tools.build_parser(ctx);
|
||||
}
|
||||
|
||||
return content.build_parser(ctx);
|
||||
parser = wrap_for_generation_prompt(p, parser, inputs, reasoning.start);
|
||||
return parser;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -130,24 +124,15 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
|
||||
return p.eps();
|
||||
}
|
||||
|
||||
bool thinking_forced_open = (mode == reasoning_mode::FORCED_OPEN);
|
||||
bool thinking_forced_closed = (mode == reasoning_mode::FORCED_CLOSED);
|
||||
|
||||
if (thinking_forced_open || thinking_forced_closed) {
|
||||
// Thinking is forced open OR forced closed with enable_thinking=true
|
||||
// In both cases, expect only the closing tag (opening was in template)
|
||||
// However, since we might have incorrectly detected the open/close pattern,
|
||||
// we admit an optional starting marker
|
||||
return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
|
||||
}
|
||||
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
|
||||
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools)
|
||||
// Both use the same tag-based pattern if markers are available
|
||||
if (!start.empty() && !end.empty()) {
|
||||
return p.optional(start + p.reasoning(p.until(end)) + end);
|
||||
if (!end.empty()) {
|
||||
if (!start.empty()) {
|
||||
// Standard tag-based: optional(<think>reasoning</think>)
|
||||
return p.optional(start + p.reasoning(p.until(end)) + end + p.space());
|
||||
}
|
||||
// Delimiter-style (empty start)
|
||||
return p.optional(p.reasoning(p.until(end)) + end + p.space());
|
||||
}
|
||||
} else if (mode == reasoning_mode::DELIMITER) {
|
||||
return p.optional(p.reasoning(p.until(end)) + end);
|
||||
}
|
||||
|
||||
return p.eps();
|
||||
@@ -335,7 +320,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
"tool-" + name + "-arg-" + param_name + "-schema",
|
||||
param_schema, true)) :
|
||||
p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, format.uses_python_dicts)) +
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
|
||||
p.space()) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)));
|
||||
|
||||
@@ -384,7 +369,9 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section) + p.space() + args_seq;
|
||||
matched_atomic = true;
|
||||
} else if (!arguments.name_prefix.empty() && properties.size() > 0) {
|
||||
} else if (!arguments.name_prefix.empty() && !required_parsers.empty()) {
|
||||
// Only peek for an arg tag when there are required args that must follow.
|
||||
// When all args are optional, the model may emit no arg tags at all (#20650).
|
||||
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq;
|
||||
matched_atomic = true;
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
#include "chat-auto-parser-helpers.h"
|
||||
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
#include "log.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include <cctype>
|
||||
#include <numeric>
|
||||
@@ -186,6 +188,21 @@ diff_split calculate_diff_split(const std::string & left, const std::string & ri
|
||||
result.suffix = "";
|
||||
// pick prefix = all as representation
|
||||
}
|
||||
|
||||
// When left has no unique content (result.left is empty), left is entirely
|
||||
// shared with right. The simultaneous prefix/suffix segment matching can
|
||||
// incorrectly consume trailing segments of left as suffix when those same
|
||||
// segments also appear at the end of right (e.g. "\n" at the end of both
|
||||
// the shared content and the generation prompt). This rotates the diff.
|
||||
// Fix: if left is a prefix of right, enforce that directly.
|
||||
if (result.left.empty() && !result.right.empty() &&
|
||||
left.size() <= right.size() &&
|
||||
right.substr(0, left.size()) == left) {
|
||||
result.prefix = left;
|
||||
result.suffix = "";
|
||||
result.right = right.substr(left.size());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -291,10 +308,26 @@ std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segm
|
||||
return result;
|
||||
}
|
||||
|
||||
common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p,
|
||||
const common_peg_parser & prs,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::string & reasoning_start) {
|
||||
auto parser = prs;
|
||||
if (!inputs.generation_prompt.empty()) {
|
||||
size_t end_pos = inputs.generation_prompt.size();
|
||||
if (!reasoning_start.empty() && inputs.generation_prompt.find(reasoning_start) != std::string::npos) {
|
||||
end_pos = inputs.generation_prompt.find(reasoning_start);
|
||||
}
|
||||
std::string cut_genprompt = inputs.generation_prompt.substr(0, end_pos);
|
||||
parser = p.literal(cut_genprompt) + parser;
|
||||
}
|
||||
return parser;
|
||||
}
|
||||
|
||||
namespace autoparser {
|
||||
|
||||
std::string apply_template(const common_chat_template & tmpl, const template_params & params) {
|
||||
templates_params tmpl_params;
|
||||
generation_params tmpl_params;
|
||||
tmpl_params.messages = params.messages;
|
||||
tmpl_params.tools = params.tools;
|
||||
tmpl_params.add_generation_prompt = params.add_generation_prompt;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "chat-auto-parser.h"
|
||||
#include "peg-parser.h"
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
@@ -57,6 +58,11 @@ std::vector<segment> segmentize_markers(const std::string & text);
|
||||
// (MARKER, "</function>"), (MARKER, "</tool_call>") ]
|
||||
std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segments);
|
||||
|
||||
// Wrap parser with generation prompt parser
|
||||
common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p,
|
||||
const common_peg_parser & prs,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::string & reasoning_start = {});
|
||||
namespace autoparser {
|
||||
|
||||
// Apply a template with the given parameters, returning the rendered string (empty on failure)
|
||||
|
||||
@@ -50,7 +50,7 @@ namespace autoparser {
|
||||
// High-level params for parser generation
|
||||
// ============================================================================
|
||||
|
||||
struct templates_params {
|
||||
struct generation_params {
|
||||
json messages;
|
||||
json tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
@@ -62,6 +62,7 @@ struct templates_params {
|
||||
bool add_generation_prompt = false;
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
std::string generation_prompt;
|
||||
json extra_context;
|
||||
bool add_bos = false;
|
||||
bool add_eos = false;
|
||||
@@ -77,11 +78,7 @@ struct templates_params {
|
||||
// Reasoning handling mode (derived from R1-R3 comparisons)
|
||||
enum class reasoning_mode {
|
||||
NONE, // No reasoning markers detected
|
||||
TAG_BASED, // Standard tag-based: <think>...</think>
|
||||
DELIMITER, // Delimiter-based: [BEGIN FINAL RESPONSE] (reasoning ends at delimiter)
|
||||
FORCED_OPEN, // Template ends with open reasoning tag (empty start, non-empty end)
|
||||
FORCED_CLOSED, // Template ends with open reasoning tag on enabled thinking but
|
||||
// with both opened and closed tag for disabled thinking
|
||||
TAG_BASED, // Tag-based: <think>...</think> (start can be empty for delimiter-style)
|
||||
TOOLS_ONLY // Only reason on tool calls, not on normal content
|
||||
};
|
||||
|
||||
@@ -91,12 +88,6 @@ inline std::ostream & operator<<(std::ostream & os, const reasoning_mode & mode)
|
||||
return os << "NONE";
|
||||
case reasoning_mode::TAG_BASED:
|
||||
return os << "TAG_BASED";
|
||||
case reasoning_mode::DELIMITER:
|
||||
return os << "DELIMITER";
|
||||
case reasoning_mode::FORCED_OPEN:
|
||||
return os << "FORCED_OPEN";
|
||||
case reasoning_mode::FORCED_CLOSED:
|
||||
return os << "FORCED_CLOSED";
|
||||
case reasoning_mode::TOOLS_ONLY:
|
||||
return os << "TOOLS_ONLY";
|
||||
default:
|
||||
@@ -184,7 +175,6 @@ struct tool_format_analysis {
|
||||
|
||||
bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "<funname>": { ... arguments ... } }
|
||||
bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...]
|
||||
bool uses_python_dicts = false; // Tool call args use Python dict format (single-quoted strings)
|
||||
|
||||
std::string function_field = "function";
|
||||
std::string name_field = "name";
|
||||
@@ -225,12 +215,12 @@ struct analyze_content;
|
||||
|
||||
struct parser_build_context {
|
||||
common_chat_peg_builder & p;
|
||||
const templates_params & inputs;
|
||||
const generation_params & inputs;
|
||||
common_peg_parser reasoning_parser;
|
||||
bool extracting_reasoning = false;
|
||||
const analyze_content * content = nullptr;
|
||||
|
||||
parser_build_context(common_chat_peg_builder & p, const templates_params & inputs);
|
||||
parser_build_context(common_chat_peg_builder & p, const generation_params & inputs);
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
@@ -260,6 +250,7 @@ struct analyze_reasoning : analyze_base {
|
||||
|
||||
analyze_reasoning() = default;
|
||||
analyze_reasoning(const common_chat_template & tmpl, bool supports_tools);
|
||||
analyze_reasoning(std::string start_, std::string end_) : start(std::move(start_)), end(std::move(end_)) {}
|
||||
|
||||
common_peg_parser build_parser(parser_build_context & ctx) const override;
|
||||
|
||||
@@ -381,7 +372,7 @@ struct autoparser {
|
||||
void analyze_template(const common_chat_template & tmpl);
|
||||
|
||||
// Build the PEG parser for this template
|
||||
common_peg_arena build_parser(const templates_params & inputs) const;
|
||||
common_peg_arena build_parser(const generation_params & inputs) const;
|
||||
|
||||
private:
|
||||
// Collect tokens from entire analysis to preserve
|
||||
@@ -395,10 +386,10 @@ struct autoparser {
|
||||
class peg_generator {
|
||||
public:
|
||||
static common_chat_params generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs);
|
||||
const struct generation_params & inputs);
|
||||
|
||||
static common_chat_params generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs,
|
||||
const struct generation_params & inputs,
|
||||
const autoparser & autoparser);
|
||||
};
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "chat-auto-parser-helpers.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "peg-parser.h"
|
||||
@@ -31,8 +32,9 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
|
||||
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
|
||||
if (tmpl.src.find("content.split('</think>')") != std::string::npos &&
|
||||
tmpl.src.find("reasoning_content") == std::string::npos &&
|
||||
tmpl.src.find("<SPECIAL_12>") == std::string::npos &&
|
||||
analysis.reasoning.mode == reasoning_mode::NONE) {
|
||||
analysis.reasoning.mode = reasoning_mode::FORCED_OPEN;
|
||||
analysis.reasoning.mode = reasoning_mode::TAG_BASED;
|
||||
analysis.reasoning.start = "<think>";
|
||||
analysis.reasoning.end = "</think>";
|
||||
analysis.preserved_tokens.push_back("<think>");
|
||||
@@ -185,7 +187,6 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
|
||||
LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str());
|
||||
LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str());
|
||||
LOG_DBG("func_close: '%s'\n", tools.function.close.c_str());
|
||||
LOG_DBG("python_dict_format: %s\n", tools.format.uses_python_dicts ? "true" : "false");
|
||||
LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str());
|
||||
LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str());
|
||||
LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str());
|
||||
@@ -295,16 +296,12 @@ void analyze_reasoning::compare_reasoning_presence() {
|
||||
}
|
||||
if (result.result.success()) {
|
||||
if (!result.tags["pre"].empty() && !result.tags["post"].empty()) {
|
||||
if (parser_wrapped.parse_anywhere_and_extract(diff.right).result.success()) { // both tags in the diff = no forced close
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
} else {
|
||||
mode = reasoning_mode::FORCED_CLOSED;
|
||||
}
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
start = trim_whitespace(result.tags["pre"]);
|
||||
end = result.tags["post"];
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
} else if (!result.tags["post"].empty()) {
|
||||
mode = reasoning_mode::DELIMITER;
|
||||
end = result.tags["post"];
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -331,53 +328,30 @@ void analyze_reasoning::compare_thinking_enabled() {
|
||||
const auto & diff = comparison->diff;
|
||||
|
||||
std::string left_trimmed = trim_whitespace(diff.left);
|
||||
std::string right_trimmed = trim_whitespace(diff.right);
|
||||
|
||||
if (left_trimmed.empty() && !diff.right.empty()) {
|
||||
std::string right_trimmed = trim_whitespace(diff.right);
|
||||
|
||||
if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) {
|
||||
if (start.empty()) {
|
||||
start = right_trimmed;
|
||||
mode = reasoning_mode::FORCED_OPEN;
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
} else if (right_trimmed.empty() && !diff.left.empty()) {
|
||||
if (!left_trimmed.empty() && string_ends_with(comparison->output_A, left_trimmed)) {
|
||||
if (end.empty()) {
|
||||
auto seg = prune_whitespace_segments(segmentize_markers(comparison->output_A));
|
||||
if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) {
|
||||
start = seg[seg.size() - 2].value;
|
||||
}
|
||||
end = left_trimmed;
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (start.empty() && !end.empty()) {
|
||||
mode = reasoning_mode::DELIMITER;
|
||||
}
|
||||
|
||||
// Check for FORCED_CLOSED: when enable_thinking=false produces both start and end markers,
|
||||
// but enable_thinking=true produces only the start marker
|
||||
if (!comparison->output_A.empty() && !comparison->output_B.empty()) {
|
||||
auto parser_start = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.literal(start) + p.space() + p.literal(end) + p.rest();
|
||||
});
|
||||
auto parser_start_end = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.tag("pre", p.literal(start)) + p.space() + p.negate(p.literal(end)) + p.rest();
|
||||
});
|
||||
if (!start.empty() && parser_start_end.parse_anywhere_and_extract(comparison->output_A).result.success() &&
|
||||
parser_start.parse_anywhere_and_extract(comparison->output_B).result.success()) {
|
||||
mode = reasoning_mode::FORCED_CLOSED;
|
||||
} else if (!end.empty()) { // we extract the starting marker now since we didn't get it earlier
|
||||
auto result = parser_start_end.parse_anywhere_and_extract(comparison->output_A);
|
||||
if (result.result.success()) {
|
||||
start = result.tags["pre"];
|
||||
mode = reasoning_mode::FORCED_CLOSED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (start.empty() && end.empty()) { // we might still have the case of "just open" and "just close"
|
||||
if (!diff.left.empty() && !diff.right.empty()) {
|
||||
auto seg_A = segmentize_markers(trim_trailing_whitespace(diff.left));
|
||||
auto seg_B = segmentize_markers(trim_trailing_whitespace(diff.right));
|
||||
if (seg_A.size() == 1 && seg_B.size() == 1) {
|
||||
mode = reasoning_mode::FORCED_CLOSED;
|
||||
start = seg_B[0].value;
|
||||
end = seg_A[0].value;
|
||||
}
|
||||
}
|
||||
if (mode == reasoning_mode::NONE && start.empty() && !end.empty()) {
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -426,16 +400,16 @@ void analyze_reasoning::compare_reasoning_scope() {
|
||||
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
|
||||
if (result.result.success()) {
|
||||
start = result.tags["pre"];
|
||||
end = result.tags["post"];
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
} else {
|
||||
auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())));
|
||||
});
|
||||
result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B);
|
||||
if (result.result.success()) {
|
||||
end = result.tags["post"];
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
} else {
|
||||
LOG_DBG(ANSI_ORANGE "%s: Unable to extracft reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
|
||||
LOG_DBG(ANSI_ORANGE "%s: Unable to extract reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
|
||||
mode = reasoning_mode::NONE;
|
||||
}
|
||||
}
|
||||
@@ -600,33 +574,23 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack,
|
||||
return;
|
||||
}
|
||||
|
||||
enum class json_quote_style { NONE, DOUBLE_QUOTES, SINGLE_QUOTES };
|
||||
|
||||
auto in_json_haystack = [&haystack](const std::string & needle) -> json_quote_style {
|
||||
auto in_json_haystack = [&haystack](const std::string & needle) -> bool {
|
||||
auto parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.choice({ p.literal("{"), p.literal(":") }) << p.choice({
|
||||
p.tag("sq", p.literal("'") + p.literal(needle) + p.literal("'")),
|
||||
p.tag("dq", p.literal("\"") + p.literal(needle) + p.literal("\"")) });
|
||||
});
|
||||
auto result = parser.parse_anywhere_and_extract(haystack);
|
||||
if (!result.result.success()) {
|
||||
return json_quote_style::NONE;
|
||||
}
|
||||
return result.tags.count("sq") && !result.tags["sq"].empty()
|
||||
? json_quote_style::SINGLE_QUOTES
|
||||
: json_quote_style::DOUBLE_QUOTES;
|
||||
return result.result.success();
|
||||
};
|
||||
|
||||
auto fun_quote = in_json_haystack(fun_name_needle);
|
||||
auto arg_quote = in_json_haystack(arg_name_needle);
|
||||
|
||||
if (fun_quote != json_quote_style::NONE) {
|
||||
if (fun_quote) {
|
||||
// no need to check further, we're in JSON land
|
||||
format.mode = tool_format::JSON_NATIVE;
|
||||
format.uses_python_dicts = (fun_quote == json_quote_style::SINGLE_QUOTES);
|
||||
} else if (arg_quote != json_quote_style::NONE) {
|
||||
} else if (arg_quote) {
|
||||
format.mode = tool_format::TAG_WITH_JSON;
|
||||
format.uses_python_dicts = (arg_quote == json_quote_style::SINGLE_QUOTES);
|
||||
} else {
|
||||
format.mode = tool_format::TAG_WITH_TAGGED;
|
||||
}
|
||||
|
||||
@@ -229,6 +229,20 @@ void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena,
|
||||
result.tool_calls.push_back(pending_tool_call.value());
|
||||
pending_tool_call.reset();
|
||||
}
|
||||
|
||||
// Discard whitespace-only reasoning content (e.g. from <think></think> prefill)
|
||||
if (!result.reasoning_content.empty()) {
|
||||
bool all_whitespace = true;
|
||||
for (char c : result.reasoning_content) {
|
||||
if (c != ' ' && c != '\n' && c != '\r' && c != '\t') {
|
||||
all_whitespace = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (all_whitespace) {
|
||||
result.reasoning_content.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
||||
|
||||
322
common/chat.cpp
322
common/chat.cpp
@@ -1,5 +1,6 @@
|
||||
#include "chat.h"
|
||||
|
||||
#include "chat-auto-parser-helpers.h"
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "common.h"
|
||||
@@ -22,6 +23,7 @@
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
@@ -760,7 +762,7 @@ static void foreach_parameter(const json &
|
||||
|
||||
std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::optional<json> & messages_override,
|
||||
const std::optional<json> & tools_override,
|
||||
const std::optional<json> & additional_context) {
|
||||
@@ -811,7 +813,7 @@ std::string common_chat_template_direct_apply(
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja
|
||||
@@ -876,8 +878,8 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
||||
// Response format parser
|
||||
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
|
||||
// Ministral wants to emit json surrounded by code fences
|
||||
return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema))
|
||||
<< "```";
|
||||
return wrap_for_generation_prompt(p, reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```",
|
||||
inputs, "[THINK]");
|
||||
}
|
||||
|
||||
// Tool call parser
|
||||
@@ -897,12 +899,13 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
||||
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
|
||||
auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls));
|
||||
|
||||
return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls;
|
||||
return wrap_for_generation_prompt(p, reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls,
|
||||
inputs, "[THINK]");
|
||||
}
|
||||
|
||||
// Content only parser
|
||||
include_grammar = false;
|
||||
return reasoning << p.content(p.rest());
|
||||
return wrap_for_generation_prompt(p, reasoning << p.content(p.rest()), inputs, "[THINK]");
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
@@ -928,7 +931,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// Copy reasoning to the "thinking" field as expected by the gpt-oss template
|
||||
@@ -936,7 +939,9 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
for (auto msg : inputs.messages) {
|
||||
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
|
||||
msg["thinking"] = msg.at("reasoning_content");
|
||||
msg.erase("content");
|
||||
if (msg.contains("tool_calls") && msg.at("tool_calls").is_array() && !msg.at("tool_calls").empty()) {
|
||||
msg.erase("content");
|
||||
}
|
||||
}
|
||||
adjusted_messages.push_back(msg);
|
||||
}
|
||||
@@ -986,7 +991,8 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
p.literal("<|channel|>final") + constraint + p.literal("<|message|>") +
|
||||
p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
|
||||
|
||||
return response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format);
|
||||
return wrap_for_generation_prompt(p, response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format),
|
||||
inputs, "<|channel|>");
|
||||
}
|
||||
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
@@ -1018,10 +1024,12 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
return tool_call | ( any + p.zero_or_more(start + any) + start + tool_call);
|
||||
}
|
||||
|
||||
return tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg));
|
||||
return wrap_for_generation_prompt(p, tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg)),
|
||||
inputs, "<|channel|>");
|
||||
}
|
||||
|
||||
return final_msg | (any + p.zero_or_more(start + any) + start + final_msg);
|
||||
return wrap_for_generation_prompt(p, final_msg | (any + p.zero_or_more(start + any) + start + final_msg),
|
||||
inputs, "<|channel|>");
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
@@ -1049,7 +1057,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
|
||||
// Functionary v3.2 - uses recipient-based format: >>>recipient\n{content}
|
||||
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
@@ -1070,13 +1078,13 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
||||
// Build content parser for >>>all\n{content}
|
||||
// When tools are present, content stops before the next ">>>" (tool call)
|
||||
// When no tools, content goes until end
|
||||
auto content_until_tool = p.literal(">>>all\n") + p.content(p.until(">>>"));
|
||||
auto content_until_end = p.literal(">>>all\n") + p.content(p.rest());
|
||||
auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>"));
|
||||
auto content_until_end = p.literal("all\n") + p.content(p.rest());
|
||||
|
||||
// If no tools or tool_choice is NONE, just parse content
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
// When no tools, just match the prefix and capture everything after
|
||||
return content_until_end + p.end();
|
||||
return wrap_for_generation_prompt(p, content_until_end + p.end(), inputs);
|
||||
}
|
||||
|
||||
// Build tool call parsers for each available function
|
||||
@@ -1088,7 +1096,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
||||
|
||||
// Tool format: >>>function_name\n{json_args}
|
||||
auto tool_parser = p.tool(
|
||||
p.tool_open(p.literal(">>>") + p.tool_name(p.literal(name)) + p.literal("\n")) +
|
||||
p.tool_open(p.tool_name(p.literal(name)) + p.literal("\n")) +
|
||||
p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))
|
||||
);
|
||||
|
||||
@@ -1099,17 +1107,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
||||
auto tools_only = p.trigger_rule("tools", p.one_or_more(tool_choice));
|
||||
auto content_and_tools = content_until_tool + tools_only;
|
||||
|
||||
auto ret = p.eps();
|
||||
if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) {
|
||||
if (inputs.parallel_tool_calls) {
|
||||
return p.choice({ content_and_tools, tools_only }) + p.end();
|
||||
ret = p.choice({ content_and_tools, tools_only }) + p.end();
|
||||
} else {
|
||||
ret = p.choice({ content_until_tool + tool_choice, tools_only }) + p.end();
|
||||
}
|
||||
return p.choice({ content_until_tool + tool_choice, tools_only }) + p.end();
|
||||
} else if (inputs.parallel_tool_calls) {
|
||||
ret = p.choice({ content_and_tools, content_only, tools_only }) + p.end();
|
||||
} else {
|
||||
auto content_and_tool = content_until_tool + tool_choice;
|
||||
ret = p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
|
||||
}
|
||||
if (inputs.parallel_tool_calls) {
|
||||
return p.choice({ content_and_tools, content_only, tools_only }) + p.end();
|
||||
}
|
||||
auto content_and_tool = content_until_tool + tool_choice;
|
||||
return p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
|
||||
return wrap_for_generation_prompt(p, ret, inputs);
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
@@ -1139,14 +1150,12 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
||||
// Kimi K2 Thinking - uses unique tool call ID format: functions.<name>:<index>
|
||||
// The ID contains both the function name and an incrementing counter
|
||||
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<think>";
|
||||
data.thinking_end_tag = "</think>";
|
||||
data.preserved_tokens = {
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_calls_section_end|>",
|
||||
@@ -1161,6 +1170,18 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
|
||||
const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>";
|
||||
const std::string SECTION_END = "<|tool_calls_section_end|>";
|
||||
const std::string CALL_BEGIN = "<|tool_call_begin|>";
|
||||
const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>";
|
||||
const std::string CALL_END = "<|tool_call_end|>";
|
||||
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
|
||||
data.thinking_start_tag = THINK_START;
|
||||
data.thinking_end_tag = THINK_END;
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
// Kimi K2 Thinking format:
|
||||
// - Reasoning: <think>{reasoning}</think>
|
||||
@@ -1172,16 +1193,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
||||
// <|tool_calls_section_end|>
|
||||
// The ID format is: functions.<function_name>:<counter> where counter is 0, 1, 2, ...
|
||||
|
||||
// Tool call markers
|
||||
const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>";
|
||||
const std::string SECTION_END = "<|tool_calls_section_end|>";
|
||||
const std::string CALL_BEGIN = "<|tool_call_begin|>";
|
||||
const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>";
|
||||
const std::string CALL_END = "<|tool_call_end|>";
|
||||
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
|
||||
// Tool call markers
|
||||
auto end = p.end();
|
||||
|
||||
// Note: this model is CRAZY. It can diverge from its supposed tool calling pattern in so many ways it's not funny.
|
||||
@@ -1193,7 +1205,8 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
||||
|
||||
// Content only parser (no tools)
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return reasoning + p.content(p.rest()) + end;
|
||||
return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end,
|
||||
inputs, THINK_START);
|
||||
}
|
||||
|
||||
// Build tool call parsers for each available function
|
||||
@@ -1229,7 +1242,8 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
||||
|
||||
auto content_before_tools = p.content(p.until_one_of({ SECTION_BEGIN, CALL_BEGIN }));
|
||||
|
||||
return reasoning + content_before_tools + tool_calls + end;
|
||||
return wrap_for_generation_prompt(p, reasoning + content_before_tools + tool_calls + end,
|
||||
inputs, THINK_START);
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
@@ -1259,7 +1273,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
||||
// - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|>
|
||||
// Tool calls can appear multiple times (parallel tool calls)
|
||||
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
@@ -1278,13 +1292,15 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
|
||||
|
||||
const std::string TOOL_CALL_START = "<|tool_call_start|>";
|
||||
const std::string TOOL_CALL_END = "<|tool_call_end|>";
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
|
||||
data.thinking_start_tag = THINK_START;
|
||||
data.thinking_end_tag = THINK_END;
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
@@ -1293,7 +1309,8 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
||||
}
|
||||
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return reasoning + p.content(p.rest()) + end;
|
||||
return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end, inputs,
|
||||
THINK_START);
|
||||
}
|
||||
|
||||
auto tool_calls = p.rule("tool-calls",
|
||||
@@ -1305,7 +1322,8 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
||||
|
||||
auto content = p.content(p.until(TOOL_CALL_START));
|
||||
|
||||
return reasoning + content + tool_calls + end;
|
||||
return wrap_for_generation_prompt(p, reasoning + content + tool_calls + end, inputs,
|
||||
THINK_START);
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
@@ -1331,7 +1349,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
||||
|
||||
static common_chat_params common_chat_params_init_gigachat_v3(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
|
||||
common_chat_params data;
|
||||
|
||||
@@ -1345,9 +1363,10 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
auto tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
|
||||
const auto *tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto ret = p.eps();
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
// Build a choice of all available tools
|
||||
auto tool_choice = p.choice();
|
||||
@@ -1370,13 +1389,14 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
||||
auto tool_call = p.rule("tool-call", p.literal(tool_call_start_prefix) + tool_choice);
|
||||
auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls));
|
||||
|
||||
return p.content(p.until("<|message_sep|>\n\n")) << tool_calls;
|
||||
ret = p.content(p.until("<|message_sep|>\n\n")) << tool_calls;
|
||||
} else {
|
||||
// Content only parser
|
||||
include_grammar = false;
|
||||
ret = p.content(p.rest());
|
||||
}
|
||||
|
||||
// Content only parser
|
||||
include_grammar = false;
|
||||
return p.content(p.rest());
|
||||
|
||||
return wrap_for_generation_prompt(p, ret, inputs);
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
@@ -1471,87 +1491,10 @@ static json common_chat_extra_context() {
|
||||
return ctx;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs) {
|
||||
autoparser::templates_params params;
|
||||
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
|
||||
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
|
||||
? *tmpls->template_tool_use
|
||||
: *tmpls->template_default;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
params.enable_thinking = inputs.enable_thinking;
|
||||
params.grammar = inputs.grammar;
|
||||
params.now = inputs.now;
|
||||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
|
||||
if (src.find("<|channel|>") == std::string::npos) {
|
||||
// map developer to system for all models except for GPT-OSS
|
||||
workaround::map_developer_role_to_system(params.messages);
|
||||
}
|
||||
|
||||
if (!tmpl.original_caps().supports_system_role) {
|
||||
workaround::system_message_not_supported(params.messages);
|
||||
}
|
||||
|
||||
if (tmpl.original_caps().supports_tool_calls) {
|
||||
// some templates will require the content field in tool call messages
|
||||
// to still be non-null, this puts an empty string everywhere where the
|
||||
// content field is null
|
||||
workaround::requires_non_null_content(params.messages);
|
||||
}
|
||||
|
||||
if (tmpl.original_caps().supports_object_arguments) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
}
|
||||
|
||||
params.extra_context = common_chat_extra_context();
|
||||
for (auto el : inputs.chat_template_kwargs) {
|
||||
params.extra_context[el.first] = json::parse(el.second);
|
||||
}
|
||||
|
||||
if (!inputs.json_schema.empty()) {
|
||||
params.json_schema = json::parse(inputs.json_schema);
|
||||
}
|
||||
|
||||
// if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
|
||||
// LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
|
||||
// params.parallel_tool_calls = false;
|
||||
// } else {
|
||||
params.parallel_tool_calls = inputs.parallel_tool_calls;
|
||||
//}
|
||||
|
||||
if (params.tools.is_array()) {
|
||||
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
}
|
||||
if (caps.supports_tool_calls && !caps.supports_tools) {
|
||||
LOG_WRN(
|
||||
"Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
|
||||
"produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (inputs.force_pure_content) {
|
||||
LOG_WRN("Forcing pure content template, will not render reasoning or tools separately.");
|
||||
// Create the result structure
|
||||
common_chat_params data;
|
||||
auto params_copy = params;
|
||||
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
auto parser = build_chat_peg_parser([](common_chat_peg_builder &p) {
|
||||
return p.content(p.rest());
|
||||
});
|
||||
data.parser = parser.save();
|
||||
return data;
|
||||
}
|
||||
|
||||
static std::optional<common_chat_params> try_specialized_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::string & src,
|
||||
const autoparser::generation_params & params) {
|
||||
// Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser
|
||||
// Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them
|
||||
if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos &&
|
||||
@@ -1592,14 +1535,105 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
// GigaChatV3 format detection
|
||||
if (src.find("<|role_sep|>") != std::string::npos &&
|
||||
src.find("<|message_sep|>") != std::string::npos &&
|
||||
src.find("<|function_call|>") == std::string::npos
|
||||
) {
|
||||
src.find("<|function_call|>") == std::string::npos) {
|
||||
LOG_DBG("Using specialized template: GigaChatV3\n");
|
||||
return common_chat_params_init_gigachat_v3(tmpl, params);
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs) {
|
||||
autoparser::generation_params params;
|
||||
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
|
||||
const auto & tmpl =
|
||||
params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
params.enable_thinking = inputs.enable_thinking;
|
||||
params.grammar = inputs.grammar;
|
||||
params.now = inputs.now;
|
||||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
|
||||
if (src.find("<|channel|>") == std::string::npos) {
|
||||
// map developer to system for all models except for GPT-OSS
|
||||
workaround::map_developer_role_to_system(params.messages);
|
||||
}
|
||||
|
||||
if (!tmpl.original_caps().supports_system_role) {
|
||||
workaround::system_message_not_supported(params.messages);
|
||||
}
|
||||
|
||||
if (tmpl.original_caps().supports_tool_calls) {
|
||||
// some templates will require the content field in tool call messages
|
||||
// to still be non-null, this puts an empty string everywhere where the
|
||||
// content field is null
|
||||
workaround::requires_non_null_content(params.messages);
|
||||
}
|
||||
|
||||
if (tmpl.original_caps().supports_object_arguments) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
}
|
||||
|
||||
params.add_generation_prompt = false;
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply(tmpl, params);
|
||||
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
|
||||
params.generation_prompt = diff.right;
|
||||
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
|
||||
params.extra_context = common_chat_extra_context();
|
||||
for (auto el : inputs.chat_template_kwargs) {
|
||||
params.extra_context[el.first] = json::parse(el.second);
|
||||
}
|
||||
|
||||
if (!inputs.json_schema.empty()) {
|
||||
params.json_schema = json::parse(inputs.json_schema);
|
||||
}
|
||||
|
||||
params.parallel_tool_calls = inputs.parallel_tool_calls;
|
||||
|
||||
if (params.tools.is_array()) {
|
||||
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
}
|
||||
if (caps.supports_tool_calls && !caps.supports_tools) {
|
||||
LOG_WRN(
|
||||
"Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
|
||||
"produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (inputs.force_pure_content) {
|
||||
LOG_WRN("Forcing pure content template, will not render reasoning or tools separately.");
|
||||
// Create the result structure
|
||||
common_chat_params data;
|
||||
auto params_copy = params;
|
||||
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.generation_prompt = params.generation_prompt;
|
||||
auto parser = build_chat_peg_parser([¶ms](common_chat_peg_builder &p) {
|
||||
return wrap_for_generation_prompt(p, p.content(p.rest()), params);
|
||||
});
|
||||
data.parser = parser.save();
|
||||
return data;
|
||||
}
|
||||
|
||||
if (auto result = try_specialized_template(tmpl, src, params)) {
|
||||
result->generation_prompt = params.generation_prompt;
|
||||
return *result;
|
||||
}
|
||||
|
||||
try {
|
||||
LOG_DBG("Using differential autoparser\n");
|
||||
LOG_DBG("%s: using differential autoparser\n", __func__);
|
||||
struct autoparser::autoparser autoparser;
|
||||
autoparser.analyze_template(tmpl);
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
@@ -1607,13 +1641,11 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
if (auto_params.supports_thinking) {
|
||||
auto_params.thinking_start_tag = autoparser.reasoning.start;
|
||||
auto_params.thinking_end_tag = autoparser.reasoning.end;
|
||||
// FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
|
||||
// (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
|
||||
// but forces <think> open when thinking is enabled)
|
||||
auto_params.thinking_forced_open =
|
||||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
|
||||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
|
||||
}
|
||||
auto_params.generation_prompt = params.generation_prompt;
|
||||
common_peg_arena arena;
|
||||
arena.load(auto_params.parser);
|
||||
LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str());
|
||||
return auto_params;
|
||||
} catch (const std::exception & e) {
|
||||
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());
|
||||
@@ -1711,14 +1743,18 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
LOG_DBG("No parser definition detected, assuming pure content parser.");
|
||||
}
|
||||
|
||||
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str());
|
||||
const std::string effective_input = params.generation_prompt.empty()
|
||||
? input
|
||||
: params.generation_prompt + input;
|
||||
|
||||
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());
|
||||
|
||||
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT;
|
||||
if (params.debug) {
|
||||
flags |= COMMON_PEG_PARSE_FLAG_DEBUG;
|
||||
}
|
||||
|
||||
common_peg_parse_context ctx(input, flags);
|
||||
common_peg_parse_context ctx(effective_input, flags);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
if (result.fail()) {
|
||||
@@ -1738,7 +1774,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
return msg;
|
||||
}
|
||||
throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end) + ": " +
|
||||
input.substr(result.end));
|
||||
effective_input.substr(result.end));
|
||||
}
|
||||
|
||||
common_chat_msg msg;
|
||||
|
||||
@@ -24,7 +24,7 @@ using json = nlohmann::ordered_json;
|
||||
struct common_chat_templates;
|
||||
|
||||
namespace autoparser {
|
||||
struct templates_params;
|
||||
struct generation_params;
|
||||
} // namespace autoparser
|
||||
|
||||
struct common_chat_tool_call {
|
||||
@@ -212,7 +212,7 @@ struct common_chat_params {
|
||||
std::string prompt;
|
||||
std::string grammar;
|
||||
bool grammar_lazy = false;
|
||||
bool thinking_forced_open = false;
|
||||
std::string generation_prompt;
|
||||
bool supports_thinking = false;
|
||||
std::string thinking_start_tag; // e.g., "<think>"
|
||||
std::string thinking_end_tag; // e.g., "</think>"
|
||||
@@ -229,14 +229,14 @@ struct common_chat_parser_params {
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
|
||||
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
std::string generation_prompt;
|
||||
bool parse_tool_calls = true;
|
||||
bool debug = false; // Enable debug output for PEG parser
|
||||
common_peg_arena parser = {};
|
||||
common_chat_parser_params() = default;
|
||||
common_chat_parser_params(const common_chat_params & chat_params) {
|
||||
format = chat_params.format;
|
||||
thinking_forced_open = chat_params.thinking_forced_open;
|
||||
format = chat_params.format;
|
||||
generation_prompt = chat_params.generation_prompt;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -302,7 +302,7 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem
|
||||
|
||||
std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::optional<json> & messages_override = std::nullopt,
|
||||
const std::optional<json> & tools_override = std::nullopt,
|
||||
const std::optional<json> & additional_context = std::nullopt);
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml-opt.h"
|
||||
#include "ggml.h"
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
@@ -178,6 +180,43 @@ enum common_speculative_type {
|
||||
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
|
||||
};
|
||||
|
||||
// Grammar type enumeration
|
||||
enum common_grammar_type {
|
||||
COMMON_GRAMMAR_TYPE_NONE, // no grammar set
|
||||
COMMON_GRAMMAR_TYPE_USER, // user-provided GBNF (--grammar / "grammar" API field)
|
||||
COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, // auto-generated from JSON schema (--json-schema / "json_schema" API field)
|
||||
COMMON_GRAMMAR_TYPE_TOOL_CALLS, // auto-generated by chat template parser for function calling
|
||||
};
|
||||
|
||||
// Grammar variant struct with type and grammar string
|
||||
struct common_grammar {
|
||||
common_grammar_type type = COMMON_GRAMMAR_TYPE_NONE;
|
||||
std::string grammar;
|
||||
|
||||
// Default constructor - no grammar
|
||||
common_grammar() = default;
|
||||
|
||||
// Constructor with type and grammar string
|
||||
common_grammar(common_grammar_type t, std::string g) : type(t), grammar(std::move(g)) {
|
||||
GGML_ASSERT(type != COMMON_GRAMMAR_TYPE_NONE || !grammar.empty());
|
||||
}
|
||||
|
||||
// Check if a grammar is set
|
||||
bool empty() const { return type == COMMON_GRAMMAR_TYPE_NONE || grammar.empty(); }
|
||||
};
|
||||
|
||||
// Returns the raw grammar string, or empty string if no grammar is set.
|
||||
inline const std::string & common_grammar_value(const common_grammar & g) {
|
||||
return g.grammar;
|
||||
}
|
||||
|
||||
// Returns true when the generation_prompt should be prefilled into the grammar sampler.
|
||||
// Only output-format and tool-call grammars need prefill; user-supplied grammars must not be prefilled.
|
||||
inline bool common_grammar_needs_prefill(const common_grammar & g) {
|
||||
return g.type == COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT
|
||||
|| g.type == COMMON_GRAMMAR_TYPE_TOOL_CALLS;
|
||||
}
|
||||
|
||||
// sampling parameters
|
||||
struct common_params_sampling {
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
||||
@@ -228,7 +267,7 @@ struct common_params_sampling {
|
||||
COMMON_SAMPLER_TYPE_TEMPERATURE,
|
||||
};
|
||||
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
common_grammar grammar; // optional grammar constraint (user / output-format / tool-calls)
|
||||
bool grammar_lazy = false;
|
||||
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
|
||||
std::set<llama_token> preserved_tokens;
|
||||
@@ -236,10 +275,15 @@ struct common_params_sampling {
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
||||
|
||||
// The assistant generation prompt already prefilled into the prompt.
|
||||
// Fed to the grammar sampler (to advance past pre-existing tokens) and used
|
||||
// to determine the reasoning budget sampler's initial state.
|
||||
// Only applied when the grammar is of output-format or tool-calls type.
|
||||
std::string generation_prompt;
|
||||
|
||||
// reasoning budget sampler parameters
|
||||
// these are populated by the server/CLI based on chat template params
|
||||
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
|
||||
bool reasoning_budget_activate_immediately = false;
|
||||
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
|
||||
|
||||
@@ -53,6 +53,13 @@ private:
|
||||
return tokens[current + offset];
|
||||
}
|
||||
|
||||
const token & next() {
|
||||
if (current >= tokens.size()) {
|
||||
throw parser_exception("Parser Error: Unexpected EOF", source, tokens.empty() ? 0 : tokens.back().pos);
|
||||
}
|
||||
return tokens[current++];
|
||||
}
|
||||
|
||||
token expect(token::type type, const std::string& error) {
|
||||
const auto & t = peek();
|
||||
if (t.t != type) {
|
||||
@@ -90,9 +97,9 @@ private:
|
||||
size_t start_pos = current;
|
||||
switch (peek().t) {
|
||||
case token::comment:
|
||||
return mk_stmt<comment_statement>(start_pos, tokens[current++].value);
|
||||
return mk_stmt<comment_statement>(start_pos, next().value);
|
||||
case token::text:
|
||||
return mk_stmt<string_literal>(start_pos, tokens[current++].value);
|
||||
return mk_stmt<string_literal>(start_pos, next().value);
|
||||
case token::open_statement:
|
||||
return parse_jinja_statement();
|
||||
case token::open_expression:
|
||||
@@ -119,8 +126,7 @@ private:
|
||||
}
|
||||
|
||||
size_t start_pos = current;
|
||||
std::string name = peek().value;
|
||||
current++; // consume identifier
|
||||
std::string name = next().value;
|
||||
|
||||
statement_ptr result;
|
||||
if (name == "set") {
|
||||
@@ -202,7 +208,7 @@ private:
|
||||
// Ignore generation blocks (transformers-specific)
|
||||
// See https://github.com/huggingface/transformers/pull/30650 for more information.
|
||||
result = mk_stmt<noop_statement>(start_pos);
|
||||
current++;
|
||||
++current;
|
||||
|
||||
} else {
|
||||
throw std::runtime_error("Unknown statement: " + name);
|
||||
@@ -217,7 +223,7 @@ private:
|
||||
statements body;
|
||||
|
||||
if (is(token::equals)) {
|
||||
current++;
|
||||
++current;
|
||||
value = parse_expression_sequence();
|
||||
} else {
|
||||
// parsing multiline set here
|
||||
@@ -280,7 +286,7 @@ private:
|
||||
exprs.push_back(primary ? parse_primary_expression() : parse_expression());
|
||||
bool is_tuple = is(token::comma);
|
||||
while (is(token::comma)) {
|
||||
current++; // consume comma
|
||||
++current; // consume comma
|
||||
exprs.push_back(primary ? parse_primary_expression() : parse_expression());
|
||||
}
|
||||
return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]);
|
||||
@@ -290,7 +296,7 @@ private:
|
||||
// e.g., `message` in `for message in messages`
|
||||
auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple
|
||||
if (!is_identifier("in")) throw std::runtime_error("Expected 'in'");
|
||||
current++;
|
||||
++current; // consume 'in'
|
||||
|
||||
// `messages` in `for message in messages`
|
||||
auto iterable = parse_expression();
|
||||
@@ -305,7 +311,8 @@ private:
|
||||
}
|
||||
|
||||
if (is_statement({"else"})) {
|
||||
current += 2;
|
||||
++current; // consume {%
|
||||
++current; // consume 'else'
|
||||
expect(token::close_statement, "Expected %}");
|
||||
while (!is_statement({"endfor"})) {
|
||||
alternate.push_back(parse_any());
|
||||
@@ -347,7 +354,7 @@ private:
|
||||
auto left = parse_logical_and_expression();
|
||||
while (is_identifier("or")) {
|
||||
size_t start_pos = current;
|
||||
token op = tokens[current++];
|
||||
token op = next();
|
||||
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression());
|
||||
}
|
||||
return left;
|
||||
@@ -357,7 +364,7 @@ private:
|
||||
auto left = parse_logical_negation_expression();
|
||||
while (is_identifier("and")) {
|
||||
size_t start_pos = current;
|
||||
auto op = tokens[current++];
|
||||
auto op = next();
|
||||
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression());
|
||||
}
|
||||
return left;
|
||||
@@ -367,7 +374,7 @@ private:
|
||||
// Try parse unary operators
|
||||
if (is_identifier("not")) {
|
||||
size_t start_pos = current;
|
||||
auto op = tokens[current++];
|
||||
auto op = next();
|
||||
return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression());
|
||||
}
|
||||
return parse_comparison_expression();
|
||||
@@ -382,11 +389,12 @@ private:
|
||||
size_t start_pos = current;
|
||||
if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") {
|
||||
op = {token::identifier, "not in", tokens[current].pos};
|
||||
current += 2;
|
||||
++current; // consume 'not'
|
||||
++current; // consume 'in'
|
||||
} else if (is_identifier("in")) {
|
||||
op = tokens[current++];
|
||||
op = next();
|
||||
} else if (is(token::comparison_binary_operator)) {
|
||||
op = tokens[current++];
|
||||
op = next();
|
||||
} else break;
|
||||
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression());
|
||||
}
|
||||
@@ -397,7 +405,7 @@ private:
|
||||
auto left = parse_multiplicative_expression();
|
||||
while (is(token::additive_binary_operator)) {
|
||||
size_t start_pos = current;
|
||||
auto op = tokens[current++];
|
||||
auto op = next();
|
||||
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression());
|
||||
}
|
||||
return left;
|
||||
@@ -407,7 +415,7 @@ private:
|
||||
auto left = parse_test_expression();
|
||||
while (is(token::multiplicative_binary_operator)) {
|
||||
size_t start_pos = current;
|
||||
auto op = tokens[current++];
|
||||
auto op = next();
|
||||
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression());
|
||||
}
|
||||
return left;
|
||||
@@ -417,9 +425,9 @@ private:
|
||||
auto operand = parse_filter_expression();
|
||||
while (is_identifier("is")) {
|
||||
size_t start_pos = current;
|
||||
current++;
|
||||
++current; // consume 'is'
|
||||
bool negate = false;
|
||||
if (is_identifier("not")) { current++; negate = true; }
|
||||
if (is_identifier("not")) { ++current; negate = true; }
|
||||
auto test_id = parse_primary_expression();
|
||||
// FIXME: tests can also be expressed like this: if x is eq 3
|
||||
if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id));
|
||||
@@ -432,7 +440,7 @@ private:
|
||||
auto operand = parse_call_member_expression();
|
||||
while (is(token::pipe)) {
|
||||
size_t start_pos = current;
|
||||
current++;
|
||||
++current; // consume pipe
|
||||
auto filter = parse_primary_expression();
|
||||
if (is(token::open_paren)) filter = parse_call_expression(std::move(filter));
|
||||
operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter));
|
||||
@@ -490,7 +498,7 @@ private:
|
||||
statement_ptr parse_member_expression(statement_ptr object) {
|
||||
size_t start_pos = current;
|
||||
while (is(token::dot) || is(token::open_square_bracket)) {
|
||||
auto op = tokens[current++];
|
||||
auto op = next();
|
||||
bool computed = op.t == token::open_square_bracket;
|
||||
statement_ptr prop;
|
||||
if (computed) {
|
||||
@@ -536,7 +544,7 @@ private:
|
||||
|
||||
statement_ptr parse_primary_expression() {
|
||||
size_t start_pos = current;
|
||||
auto t = tokens[current++];
|
||||
auto t = next();
|
||||
switch (t.t) {
|
||||
case token::numeric_literal:
|
||||
if (t.value.find('.') != std::string::npos) {
|
||||
@@ -547,7 +555,7 @@ private:
|
||||
case token::string_literal: {
|
||||
std::string val = t.value;
|
||||
while (is(token::string_literal)) {
|
||||
val += tokens[current++].value;
|
||||
val += next().value;
|
||||
}
|
||||
return mk_stmt<string_literal>(start_pos, val);
|
||||
}
|
||||
@@ -562,9 +570,9 @@ private:
|
||||
statements vals;
|
||||
while (!is(token::close_square_bracket)) {
|
||||
vals.push_back(parse_expression());
|
||||
if (is(token::comma)) current++;
|
||||
if (is(token::comma)) ++current;
|
||||
}
|
||||
current++;
|
||||
++current;
|
||||
return mk_stmt<array_literal>(start_pos, std::move(vals));
|
||||
}
|
||||
case token::open_curly_bracket: {
|
||||
@@ -573,9 +581,9 @@ private:
|
||||
auto key = parse_expression();
|
||||
expect(token::colon, "Expected :");
|
||||
pairs.push_back({std::move(key), parse_expression()});
|
||||
if (is(token::comma)) current++;
|
||||
if (is(token::comma)) ++current;
|
||||
}
|
||||
current++;
|
||||
++current;
|
||||
return mk_stmt<object_literal>(start_pos, std::move(pairs));
|
||||
}
|
||||
default:
|
||||
|
||||
@@ -451,7 +451,7 @@ struct value_array_t : public value_t {
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence());
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), other.val_arr.end(), value_equivalence());
|
||||
}
|
||||
};
|
||||
using value_array = std::shared_ptr<value_array_t>;
|
||||
@@ -587,7 +587,7 @@ struct value_object_t : public value_t {
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence());
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), other.val_obj.end(), value_equivalence());
|
||||
}
|
||||
};
|
||||
using value_object = std::shared_ptr<value_object_t>;
|
||||
|
||||
@@ -163,9 +163,15 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
|
||||
ctx->force_pos = 0;
|
||||
}
|
||||
|
||||
// forward declaration for use in clone
|
||||
static struct llama_sampler * common_reasoning_budget_init_state(
|
||||
const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget, common_reasoning_budget_state initial_state);
|
||||
|
||||
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
|
||||
return common_reasoning_budget_init(
|
||||
return common_reasoning_budget_init_state(
|
||||
ctx->vocab,
|
||||
ctx->start_matcher.tokens,
|
||||
ctx->end_matcher.tokens,
|
||||
@@ -191,13 +197,13 @@ static struct llama_sampler_i common_reasoning_budget_i = {
|
||||
/* .backend_set_input = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state) {
|
||||
static struct llama_sampler * common_reasoning_budget_init_state(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state) {
|
||||
// promote COUNTING with budget <= 0 to FORCING
|
||||
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
|
||||
initial_state = REASONING_BUDGET_FORCING;
|
||||
@@ -217,3 +223,41 @@ struct llama_sampler * common_reasoning_budget_init(
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
const std::vector<llama_token> & prefill_tokens) {
|
||||
// Determine initial state from prefill: COUNTING if the prefill begins with
|
||||
// the start sequence but does not also contain the end sequence after it.
|
||||
common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE;
|
||||
if (!prefill_tokens.empty() && !start_tokens.empty() &&
|
||||
prefill_tokens.size() >= start_tokens.size() &&
|
||||
std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) {
|
||||
initial_state = REASONING_BUDGET_COUNTING;
|
||||
// If the end sequence also follows the start in the prefill, reasoning
|
||||
// was opened and immediately closed — stay IDLE.
|
||||
if (!end_tokens.empty() &&
|
||||
prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) {
|
||||
auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size();
|
||||
if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() &&
|
||||
std::equal(end_tokens.begin(), end_tokens.end(), end_start)) {
|
||||
initial_state = REASONING_BUDGET_IDLE;
|
||||
}
|
||||
}
|
||||
}
|
||||
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
|
||||
}
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state) {
|
||||
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
|
||||
}
|
||||
|
||||
@@ -24,14 +24,26 @@ enum common_reasoning_budget_state {
|
||||
// DONE: passthrough forever
|
||||
//
|
||||
// Parameters:
|
||||
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
|
||||
// start_tokens - token sequence that activates counting
|
||||
// end_tokens - token sequence for natural deactivation
|
||||
// forced_tokens - token sequence forced when budget expires
|
||||
// budget - max tokens allowed in the reasoning block
|
||||
// initial_state - initial state of the sampler (e.g. IDLE or COUNTING)
|
||||
// note: COUNTING with budget <= 0 is promoted to FORCING
|
||||
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
|
||||
// start_tokens - token sequence that activates counting
|
||||
// end_tokens - token sequence for natural deactivation
|
||||
// forced_tokens - token sequence forced when budget expires
|
||||
// budget - max tokens allowed in the reasoning block
|
||||
// prefill_tokens - tokens already present in the prompt (generation prompt);
|
||||
// used to determine the initial state: COUNTING if they begin
|
||||
// with start_tokens (but don't also end with end_tokens),
|
||||
// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
|
||||
//
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
const std::vector<llama_token> & prefill_tokens = {});
|
||||
|
||||
// Variant that takes an explicit initial state (used by tests and clone).
|
||||
// COUNTING with budget <= 0 is promoted to FORCING.
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
#include "sampling.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "ggml.h"
|
||||
#include "log.h"
|
||||
#include "reasoning-budget.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||
// TODO: deduplicate with llama-impl.h
|
||||
@@ -189,9 +192,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
|
||||
std::vector<llama_sampler *> samplers;
|
||||
|
||||
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
||||
const std::string & grammar_str = common_grammar_value(params.grammar);
|
||||
if (grammar_str.compare(0, 11, "%llguidance") == 0) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
|
||||
grmr = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str());
|
||||
#else
|
||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
@@ -240,17 +244,46 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
trigger_patterns_c.push_back(regex.c_str());
|
||||
}
|
||||
|
||||
if (!params.grammar.empty()) {
|
||||
if (!grammar_str.empty()) {
|
||||
if (params.grammar_lazy) {
|
||||
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
||||
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, grammar_str.c_str(), "root",
|
||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
trigger_tokens.data(), trigger_tokens.size());
|
||||
} else {
|
||||
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
||||
grmr = llama_sampler_init_grammar(vocab, grammar_str.c_str(), "root");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Feed generation prompt tokens to the grammar sampler so it advances past
|
||||
// tokens the template already placed in the prompt.
|
||||
// Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled.
|
||||
std::vector<llama_token> prefill_tokens;
|
||||
if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) {
|
||||
GGML_ASSERT(vocab != nullptr);
|
||||
prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true);
|
||||
if (!prefill_tokens.empty()) {
|
||||
std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true);
|
||||
if (std::isspace(first_token[0]) && !std::isspace(params.generation_prompt[0])) {
|
||||
// Some tokenizers will add a space before the first special token, need to remove
|
||||
prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end());
|
||||
}
|
||||
}
|
||||
|
||||
if (grmr) {
|
||||
try {
|
||||
for (const auto & token : prefill_tokens) {
|
||||
llama_sampler_accept(grmr, token);
|
||||
LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token);
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
|
||||
common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reasoning budget sampler — added first so it can force tokens before other samplers
|
||||
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
|
||||
samplers.push_back(common_reasoning_budget_init(
|
||||
@@ -259,7 +292,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
params.reasoning_budget_end,
|
||||
params.reasoning_budget_forced,
|
||||
params.reasoning_budget_tokens,
|
||||
params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE));
|
||||
prefill_tokens));
|
||||
}
|
||||
|
||||
if (params.has_logit_bias()) {
|
||||
|
||||
@@ -31,10 +31,10 @@ import gguf
|
||||
from gguf.vocab import MistralTokenizerType, MistralVocab
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.base import TokenizerVersion # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.base import TokenizerVersion # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # type: ignore[import-not-found]
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
|
||||
@@ -45,9 +45,9 @@ except ImportError:
|
||||
_MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
||||
|
||||
_mistral_common_installed = False
|
||||
TokenizerVersion = None
|
||||
Tekkenizer = None
|
||||
SentencePieceTokenizer = None
|
||||
TokenizerVersion: Any = None
|
||||
Tekkenizer: Any = None
|
||||
SentencePieceTokenizer: Any = None
|
||||
_mistral_import_error_msg = (
|
||||
"Mistral format requires `mistral-common` to be installed. Please run "
|
||||
"`pip install mistral-common[image,audio]` to install it."
|
||||
@@ -145,6 +145,7 @@ class ModelBase:
|
||||
self.model_name = model_name
|
||||
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
||||
self._is_nvfp4 = False
|
||||
self._is_mxfp4 = False
|
||||
|
||||
# Apply heuristics to figure out typical tensor encoding based on first tensor's dtype
|
||||
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
|
||||
@@ -220,7 +221,7 @@ class ModelBase:
|
||||
if weight_map is None or not isinstance(weight_map, dict):
|
||||
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
|
||||
tensor_names_from_index.update(weight_map.keys())
|
||||
part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None)
|
||||
part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None) # ty: ignore[invalid-assignment]
|
||||
part_names = sorted(part_dict.keys())
|
||||
else:
|
||||
weight_map = {}
|
||||
@@ -712,6 +713,7 @@ class ModelBase:
|
||||
def prepare_tensors(self):
|
||||
# detect NVFP4 quantization (ModelOpt format)
|
||||
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
|
||||
quant_method = (self.hparams.get("quantization_config") or {}).get("quant_method")
|
||||
quant_layers = (self.hparams.get("quantization_config") or {}).get("quantized_layers") or {}
|
||||
quant_config_file = self.dir_model / "hf_quant_config.json"
|
||||
|
||||
@@ -728,6 +730,7 @@ class ModelBase:
|
||||
quant_algo = "NVFP4"
|
||||
|
||||
self._is_nvfp4 = quant_algo == "NVFP4"
|
||||
self._is_mxfp4 = quant_method == "mxfp4"
|
||||
|
||||
# NVFP4 weights are repacked and written directly to gguf_writer.
|
||||
# This must run before dequant_model so NVFP4 tensors are removed
|
||||
@@ -876,6 +879,12 @@ class ModelBase:
|
||||
if self.metadata.name is None:
|
||||
self.metadata.name = self.dir_model.name
|
||||
|
||||
if self.ftype in (gguf.LlamaFileType.ALL_F32, gguf.LlamaFileType.MOSTLY_F16, gguf.LlamaFileType.MOSTLY_BF16):
|
||||
if self._is_nvfp4:
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_NVFP4
|
||||
elif self._is_mxfp4:
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_MXFP4_MOE
|
||||
|
||||
# Generate parameter weight class (useful for leader boards) if not yet determined
|
||||
if self.metadata.size_label is None and total_params > 0:
|
||||
self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count)
|
||||
@@ -1062,6 +1071,10 @@ class TextModel(ModelBase):
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
logger.info(f"gguf: key-value head count = {n_head_kv}")
|
||||
|
||||
if self.hparams.get("is_causal") is False:
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
logger.info("gguf: causal attention = False")
|
||||
|
||||
# TODO: Handle "sliding_attention" similarly when models start implementing it
|
||||
rope_params = self.rope_parameters.get("full_attention", self.rope_parameters)
|
||||
if (rope_type := rope_params.get("rope_type")) is not None:
|
||||
@@ -4260,6 +4273,16 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
|
||||
|
||||
@ModelBase.register("InternVisionModel")
|
||||
class InternVisionModel(MmprojModel):
|
||||
|
||||
min_dynamic_tiles: int = 0
|
||||
max_dynamic_tiles: int = 0
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
self.min_dynamic_tiles = self.global_config.get("min_dynamic_patch", 0)
|
||||
self.max_dynamic_tiles = self.global_config.get("max_dynamic_patch", 0)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
assert self.hparams_vision is not None
|
||||
if isinstance(self.hparams_vision['image_size'], list):
|
||||
@@ -4282,6 +4305,11 @@ class InternVisionModel(MmprojModel):
|
||||
downsample_ratio = self.global_config.get("downsample_ratio")
|
||||
assert downsample_ratio is not None
|
||||
self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio))
|
||||
# older models may not have min/max_dynamic_patch in config
|
||||
if self.min_dynamic_tiles > 0:
|
||||
self.gguf_writer.add_vision_preproc_min_tiles(self.min_dynamic_tiles)
|
||||
if self.max_dynamic_tiles > 0:
|
||||
self.gguf_writer.add_vision_preproc_max_tiles(self.max_dynamic_tiles)
|
||||
|
||||
def tensor_force_quant(self, name, new_name, bid, n_dims):
|
||||
if ".position_embd." in new_name:
|
||||
@@ -5878,7 +5906,7 @@ class InternLM2Model(TextModel):
|
||||
logger.error(f'Error: Missing {tokenizer_path}')
|
||||
sys.exit(1)
|
||||
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
|
||||
|
||||
@@ -6199,7 +6227,7 @@ class BertModel(TextModel):
|
||||
|
||||
vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size)
|
||||
else:
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
|
||||
|
||||
@@ -8876,7 +8904,7 @@ class T5Model(TextModel):
|
||||
if not tokenizer_path.is_file():
|
||||
raise FileNotFoundError(f"File not found: {tokenizer_path}")
|
||||
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
|
||||
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
|
||||
@@ -9013,7 +9041,7 @@ class T5EncoderModel(TextModel):
|
||||
if not tokenizer_path.is_file():
|
||||
raise FileNotFoundError(f"File not found: {tokenizer_path}")
|
||||
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
|
||||
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
|
||||
@@ -11121,8 +11149,7 @@ class GptOssModel(TextModel):
|
||||
|
||||
# TODO: remove once MXFP4 is supported more generally
|
||||
def dequant_model(self):
|
||||
quant_config = self.hparams.get("quantization_config")
|
||||
if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
|
||||
if self._is_mxfp4:
|
||||
return
|
||||
return super().dequant_model()
|
||||
|
||||
@@ -12275,6 +12302,7 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||
kwargs = {}
|
||||
|
||||
if func is torch.Tensor.numpy:
|
||||
assert len(args)
|
||||
return args[0].numpy()
|
||||
|
||||
return cls._wrap_fn(func)(*args, **kwargs)
|
||||
|
||||
@@ -112,11 +112,11 @@ class Tensor:
|
||||
(n_dims, name_len, dtype) = struct.unpack('<3I', data[offset:offset + 12])
|
||||
assert n_dims >= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}'
|
||||
assert name_len < 4096, 'Absurd tensor name length'
|
||||
quant = gguf.GGML_QUANT_SIZES.get(dtype)
|
||||
self.dtype = gguf.GGMLQuantizationType(dtype)
|
||||
quant = gguf.GGML_QUANT_SIZES.get(self.dtype)
|
||||
assert quant is not None, 'Unknown tensor type'
|
||||
(blksize, tysize) = quant
|
||||
offset += 12
|
||||
self.dtype= gguf.GGMLQuantizationType(dtype)
|
||||
self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)])
|
||||
offset += 4 * n_dims
|
||||
self.name = bytes(data[offset:offset + name_len])
|
||||
|
||||
@@ -199,10 +199,13 @@ class LoraTorchTensor:
|
||||
kwargs = {}
|
||||
|
||||
if func is torch.permute:
|
||||
assert len(args)
|
||||
return type(args[0]).permute(*args, **kwargs)
|
||||
elif func is torch.reshape:
|
||||
assert len(args)
|
||||
return type(args[0]).reshape(*args, **kwargs)
|
||||
elif func is torch.stack:
|
||||
assert len(args)
|
||||
assert isinstance(args[0], Sequence)
|
||||
dim = kwargs.get("dim", 0)
|
||||
assert dim == 0
|
||||
@@ -211,6 +214,7 @@ class LoraTorchTensor:
|
||||
torch.stack([b._lora_B for b in args[0]], dim),
|
||||
)
|
||||
elif func is torch.cat:
|
||||
assert len(args)
|
||||
assert isinstance(args[0], Sequence)
|
||||
dim = kwargs.get("dim", 0)
|
||||
assert dim == 0
|
||||
@@ -362,7 +366,7 @@ if __name__ == '__main__':
|
||||
logger.error(f"Model {hparams['architectures'][0]} is not supported")
|
||||
sys.exit(1)
|
||||
|
||||
class LoraModel(model_class):
|
||||
class LoraModel(model_class): # ty: ignore[unsupported-base]
|
||||
model_arch = model_class.model_arch
|
||||
|
||||
lora_alpha: float
|
||||
|
||||
@@ -14,7 +14,7 @@ The unified auto-parser uses a pure differential, compositional approach (inspir
|
||||
**Analysis + Parser Building in Two Steps**:
|
||||
|
||||
1. `autoparser::autoparser tmpl_analysis(tmpl)` — runs all differential comparisons and populates the analysis structs
|
||||
2. `autoparser::peg_generator::generate_parser(tmpl, params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar
|
||||
2. `autoparser::peg_generator::generate_parser(tmpl, generation_params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar
|
||||
|
||||
## Data Structures
|
||||
|
||||
@@ -34,7 +34,7 @@ All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h
|
||||
|
||||
### `analyze_tools` and its sub-structs
|
||||
|
||||
- [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`, `uses_python_dicts`)
|
||||
- [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`)
|
||||
- [common/chat-auto-parser.h:196-200](common/chat-auto-parser.h#L196-L200) — `tool_function_analysis`: `name_prefix`, `name_suffix`, `close` markers around function names
|
||||
- [common/chat-auto-parser.h:202-210](common/chat-auto-parser.h#L202-L210) — `tool_arguments_analysis`: `start/end` container markers, `name_prefix/suffix`, `value_prefix/suffix`, `separator`
|
||||
- [common/chat-auto-parser.h:212-217](common/chat-auto-parser.h#L212-L217) — `tool_id_analysis`: `pos` enum, `prefix`/`suffix` markers around call ID values
|
||||
@@ -47,12 +47,21 @@ All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h
|
||||
| Value | Description |
|
||||
|-----------------|-----------------------------------------------------------------------------------|
|
||||
| `NONE` | No reasoning markers detected |
|
||||
| `TAG_BASED` | Standard tag-based: `<think>...</think>` |
|
||||
| `DELIMITER` | Delimiter-based: reasoning ends at a delimiter (e.g., `[BEGIN FINAL RESPONSE]`) |
|
||||
| `FORCED_OPEN` | Template ends with open reasoning tag when `enable_thinking=true` |
|
||||
| `FORCED_CLOSED` | `enable_thinking=false` emits both tags; `enable_thinking=true` emits only start |
|
||||
| `TAG_BASED` | Tag-based: `<think>...</think>` (start can be empty for delimiter-style formats) |
|
||||
| `TOOLS_ONLY` | Reasoning only appears in tool call responses, not plain content |
|
||||
|
||||
**Generation Prompt & Reasoning Prefill**: Computed in `common_chat_templates_apply_jinja` before invoking either the specialized handlers or the auto-parser, by rendering the template twice — once with `add_generation_prompt=false` and once with `add_generation_prompt=true` — and storing the diff suffix as `generation_params::generation_prompt`. This string is propagated into `common_chat_params::generation_prompt` and `common_chat_parser_params::generation_prompt`.
|
||||
|
||||
The generation prompt is prepended to model output before PEG parsing via `wrap_for_generation_prompt()`. The portion *before* the reasoning start marker (if any) is prepended as a literal to ensure any boilerplate added by the template is consumed. The full string is also fed to the grammar sampler via `llama_sampler_accept` (stored in `common_params_sampling::grammar_prefill`), advancing the grammar past tokens already in the prompt. It is used to determine the reasoning budget sampler's initial state — COUNTING if the prefill tokens begin with the reasoning start sequence (but don't also contain the end sequence), IDLE otherwise.
|
||||
|
||||
**`grammar_prefill`** (`common_params_sampling`): The generation prompt string tokenized and accepted by the grammar sampler at init time. Only applied when `grammar_external` is false (i.e., the grammar was not set explicitly by the user).
|
||||
|
||||
Three outcomes for reasoning-prefill handling (in `generate_parser()`):
|
||||
|
||||
1. **Start+end in generation prompt** (e.g. `<think></think>\n`): the parser sees reasoning as opened and immediately closed; whitespace-only reasoning content is discarded.
|
||||
2. **Only start in generation prompt** (e.g. `<think>\n`): the parser sees reasoning as already open.
|
||||
3. **Start marker present but not at the end** (e.g. Apriel's `<|begin_assistant|>` followed by boilerplate): the marker is a template artifact; the start literal is cleared so reasoning uses delimiter-style (end-only). For templates that ignore `add_generation_prompt` (empty diff), the rendered `data.prompt` is used as fallback — but only for non-TOOLS_ONLY modes, since in TOOLS_ONLY the start tag is model-generated and may appear in prior conversation turns.
|
||||
|
||||
**`content_mode`**: How the template wraps assistant content.
|
||||
|
||||
| Value | Description |
|
||||
@@ -261,16 +270,16 @@ Text is segmentized into markers and non-marker fragments using `segmentize_mark
|
||||
|
||||
- Searches `diff.right` (output with reasoning) for the reasoning content needle
|
||||
- Uses PEG parsers to find surrounding markers:
|
||||
- If both pre/post markers found in `diff.right` → `TAG_BASED` (both tags visible in diff = no forced close)
|
||||
- If both found but post marker only in the full output B → `FORCED_CLOSED`
|
||||
- If only post marker found → `DELIMITER`
|
||||
- If both pre/post markers found in `diff.right` → `TAG_BASED`
|
||||
- If both found but post marker only in the full output B → `TAG_BASED` (template forces markers; handled via prefill)
|
||||
- If only post marker found → `TAG_BASED` (delimiter-style, empty start)
|
||||
- Sets `reasoning.start` and `reasoning.end`
|
||||
|
||||
**R2 — `compare_thinking_enabled()`**: Compares `enable_thinking=false` vs `true` with a generation prompt.
|
||||
|
||||
- Detects `FORCED_OPEN`: `enable_thinking=true` adds a non-empty marker at the end of the prompt (where model will start generating) — sets `reasoning.start`, mode = `FORCED_OPEN`
|
||||
- Detects `FORCED_CLOSED`: `enable_thinking=false` produces both start+end markers; `enable_thinking=true` produces only start marker
|
||||
- Handles the reverse case: if both start and end are still empty, looks for a single-segment diff on each side to extract both markers
|
||||
- Detects template-added reasoning markers: `enable_thinking=true` appends a non-empty marker → sets `reasoning.start`, mode = `TAG_BASED`
|
||||
- Handles the reverse case (`enable_thinking=false` appends the marker instead): extracts both start (from the preceding segment) and end markers; mode = `TAG_BASED`
|
||||
- The reasoning prefill (markers added by the template) is later extracted in `common_chat_templates_apply_jinja` and prepended to model output before parsing
|
||||
|
||||
**R3 — `compare_reasoning_scope()`**: Compares assistant message with reasoning+text-content vs reasoning+tool-calls.
|
||||
|
||||
@@ -343,7 +352,7 @@ Classification logic:
|
||||
|
||||
A workaround array in `common/chat-diff-analyzer.cpp` applies post-hoc patches after analysis. Each workaround is a lambda that inspects the template source and overrides analysis results. Current workarounds:
|
||||
|
||||
1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('</think>')`: sets `reasoning.mode = FORCED_OPEN` with `<think>`/`</think>` markers if no reasoning was detected
|
||||
1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('</think>')` but not `<SPECIAL_12>`: sets `reasoning.mode = TAG_BASED` with `<think>`/`</think>` markers if no reasoning was detected
|
||||
2. **Granite 3.3** — source contains specific "Write your thoughts" text: forces `TAG_BASED` reasoning with `<think>`/`</think>` and `WRAPPED_WITH_REASONING` content with `<response>`/`</response>`
|
||||
3. **Cohere Command R+** — source contains `<|CHATBOT_TOKEN|>`: sets `ALWAYS_WRAPPED` content mode if no content start is already set
|
||||
4. **Functionary 3.1** — source contains `set has_code_interpreter`: forces `PLAIN` content, specific `per_call_start/end`, clears preserved tokens to only keep Functionary-specific markers
|
||||
@@ -355,12 +364,13 @@ Each analyzer struct (`analyze_reasoning`, `analyze_content`, `analyze_tools`) i
|
||||
|
||||
#### Reasoning Parser (`analyze_reasoning::build_parser`)
|
||||
|
||||
| Mode | Parser |
|
||||
|-----------------------------------|---------------------------------------------------------------------|
|
||||
| Not extracting reasoning | `eps()` |
|
||||
| `FORCED_OPEN` or `FORCED_CLOSED` | `reasoning(until(end)) + end` — opening tag was in the prompt |
|
||||
| `TAG_BASED` or `TOOLS_ONLY` | `optional(start + reasoning(until(end)) + end)` |
|
||||
| `DELIMITER` | `optional(reasoning(until(end)) + end)` — no start marker |
|
||||
| Mode | Parser |
|
||||
|-----------------------------------------------|---------------------------------------------------------------------------|
|
||||
| Not extracting reasoning | `eps()` |
|
||||
| `TAG_BASED` or `TOOLS_ONLY` (non-empty start) | `optional(start + reasoning(until(end)) + end + space())` |
|
||||
| `TAG_BASED` or `TOOLS_ONLY` (empty start) | `optional(reasoning(until(end)) + end + space())` — delimiter-style |
|
||||
|
||||
Note: The start marker may be empty either because the analyzer detected delimiter-style reasoning, or because `generate_parser()` cleared a template artifact start marker (see Generation Prompt & Reasoning Prefill above). Whitespace-only reasoning content (e.g. from a `<think></think>` prefill) is discarded by the mapper.
|
||||
|
||||
#### Content Parser (`analyze_content::build_parser`)
|
||||
|
||||
@@ -410,9 +420,7 @@ All three tool parsers return:
|
||||
reasoning + optional(content(until(trigger_marker))) + tool_calls + end()
|
||||
```
|
||||
|
||||
### Python Dict Format
|
||||
|
||||
When `format.uses_python_dicts` is true (detected when single-quoted strings appear in JSON argument context), `build_parser()` pre-registers a `json-string` rule that accepts both single-quoted and double-quoted strings. This is done before any `p.json()` call so all JSON parsing inherits the flexible rule.
|
||||
Each returned parser is wrapped by `wrap_for_generation_prompt()`, which prepends a literal for any boilerplate prefix of the generation prompt (the portion before the reasoning start marker).
|
||||
|
||||
## Mapper
|
||||
|
||||
@@ -421,22 +429,22 @@ When `format.uses_python_dicts` is true (detected when single-quoted strings app
|
||||
- **Buffered arguments**: Before `tool_name` is known, argument text goes to `args_buffer`; once the name is set, the buffer is flushed to `current_tool->arguments`
|
||||
- **`args_target()`**: Returns a reference to whichever destination is currently active (buffer or tool args), eliminating branching
|
||||
- **`closing_quote_pending`**: Tracks whether a closing `"` needs to be appended when a string argument value is finalized (for schema-declared string types in tagged format)
|
||||
- **Quote normalization**: Python-style quotes (`'key': 'value'`) are converted to JSON (`"key": "value"`)
|
||||
- **Whitespace-only reasoning**: Reasoning content that consists entirely of whitespace (e.g. from a `<think></think>` prefill) is cleared so the message shows no reasoning
|
||||
- **Brace auto-closing**: At tool close, unclosed `{` braces are closed automatically
|
||||
|
||||
## Files
|
||||
|
||||
| File | Purpose |
|
||||
|-------------------------------------------|----------------------------------------------------------------------|
|
||||
| `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `templates_params` |
|
||||
| `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods |
|
||||
| `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds |
|
||||
| `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, |
|
||||
| | `compare_variants()`, string helpers |
|
||||
| `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers |
|
||||
| `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` |
|
||||
| `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis |
|
||||
| `tools/parser/template-analysis.cpp` | Template analysis tool |
|
||||
| File | Purpose |
|
||||
|-------------------------------------------|---------------------------------------------------------------------------------|
|
||||
| `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `generation_params` |
|
||||
| `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods |
|
||||
| `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds |
|
||||
| `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, `compare_variants()`, |
|
||||
| | `wrap_for_generation_prompt()`, string helpers |
|
||||
| `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers |
|
||||
| `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` |
|
||||
| `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis |
|
||||
| `tools/parser/template-analysis.cpp` | Template analysis tool |
|
||||
|
||||
## Testing & Debugging
|
||||
|
||||
@@ -516,10 +524,10 @@ To support a new template format:
|
||||
|
||||
## Edge Cases and Quirks
|
||||
|
||||
1. **Forced Thinking**: When `enable_thinking=true` and the model prompt ends with an open reasoning tag (e.g., `<think>`), the parser enters forced thinking mode and immediately expects reasoning content without waiting for a start marker.
|
||||
1. **Generation Prompt & Reasoning Prefill**: The generation prompt is extracted by diffing `add_generation_prompt=false` vs `true` in `common_chat_templates_apply_jinja`, so it contains exactly what the template appends — avoiding false positives from prior conversation turns.
|
||||
2. **Per-Call vs Per-Section Markers**: Some templates wrap each tool call individually (`per_call_start/end`); others wrap the entire section (`section_start/end`). T2 (`check_per_call_markers()`) disambiguates by checking if the second call in a two-call output starts with the section marker.
|
||||
3. **Python Dict Format**: The Seed template family uses single-quoted JSON (`'key': 'value'`). The `uses_python_dicts` flag causes the PEG builder to register a flexible `json-string` rule accepting both quote styles before any JSON rules are built.
|
||||
4. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `<tag>` or `[marker]` tokens, ensuring clean extraction.
|
||||
5. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case.
|
||||
6. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`.
|
||||
7. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats.
|
||||
3. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `<tag>` or `[marker]` tokens, ensuring clean extraction.
|
||||
4. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case.
|
||||
5. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`.
|
||||
6. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats.
|
||||
7. **Undetected Tool Format**: If `analyze_tools` concludes tool calling is supported but cannot determine the format, `build_parser()` logs an error and returns `eps()` (graceful degradation) rather than aborting.
|
||||
|
||||
70
docs/ops.md
70
docs/ops.md
@@ -12,9 +12,9 @@ Legend:
|
||||
- 🟡 Partially supported by this backend
|
||||
- ❌ Not supported by this backend
|
||||
|
||||
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN |
|
||||
| Operation | BLAS | CANN | CPU | CUDA | MTL | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN |
|
||||
|-----------|------|------|------|------|------|------|------|------|------|------|------|
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
@@ -23,63 +23,63 @@ Legend:
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DIAG | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
|
||||
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_1D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_1D | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
@@ -91,31 +91,31 @@ Legend:
|
||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
|
||||
32655
docs/ops/Metal.csv
32655
docs/ops/Metal.csv
File diff suppressed because it is too large
Load Diff
8714
docs/ops/WebGPU.csv
8714
docs/ops/WebGPU.csv
File diff suppressed because it is too large
Load Diff
@@ -28,9 +28,6 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
|
||||
return f'({result})?' if min_items == 0 else result
|
||||
|
||||
def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
|
||||
has_min = min_value != None
|
||||
has_max = max_value != None
|
||||
|
||||
def digit_range(from_char: str, to_char: str):
|
||||
out.append("[")
|
||||
if from_char == to_char:
|
||||
@@ -106,7 +103,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
|
||||
out.append(to_str[i])
|
||||
out.append("]")
|
||||
|
||||
if has_min and has_max:
|
||||
if min_value is not None and max_value is not None:
|
||||
if min_value < 0 and max_value < 0:
|
||||
out.append("\"-\" (")
|
||||
_generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
|
||||
@@ -133,7 +130,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
|
||||
|
||||
less_decimals = max(decimals_left - 1, 1)
|
||||
|
||||
if has_min:
|
||||
if min_value is not None:
|
||||
if min_value < 0:
|
||||
out.append("\"-\" (")
|
||||
_generate_min_max_int(None, -min_value, out, decimals_left, top_level=False)
|
||||
@@ -177,7 +174,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
|
||||
more_digits(length - 1, less_decimals)
|
||||
return
|
||||
|
||||
if has_max:
|
||||
if max_value is not None:
|
||||
if max_value >= 0:
|
||||
if top_level:
|
||||
out.append("\"-\" [1-9] ")
|
||||
|
||||
@@ -64,7 +64,7 @@ def load_model_and_tokenizer(model_path, use_sentence_transformers=False, device
|
||||
print("Using SentenceTransformer to apply all numbered layers")
|
||||
model = SentenceTransformer(model_path)
|
||||
tokenizer = model.tokenizer
|
||||
config = model[0].auto_model.config # type: ignore
|
||||
config = model[0].auto_model.config
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
@@ -108,8 +108,8 @@ def load_model_and_tokenizer(model_path, use_sentence_transformers=False, device
|
||||
print(f"Model file: {type(model).__module__}")
|
||||
|
||||
# Verify the model is using the correct sliding window
|
||||
if hasattr(model.config, 'sliding_window'): # type: ignore
|
||||
print(f"Model's sliding_window: {model.config.sliding_window}") # type: ignore
|
||||
if hasattr(model.config, 'sliding_window'):
|
||||
print(f"Model's sliding_window: {model.config.sliding_window}")
|
||||
else:
|
||||
print("Model config does not have sliding_window attribute")
|
||||
|
||||
@@ -152,7 +152,7 @@ def main():
|
||||
device = next(model.parameters()).device
|
||||
else:
|
||||
# For SentenceTransformer, get device from the underlying model
|
||||
device = next(model[0].auto_model.parameters()).device # type: ignore
|
||||
device = next(model[0].auto_model.parameters()).device
|
||||
|
||||
model_name = os.path.basename(model_path)
|
||||
|
||||
@@ -177,7 +177,7 @@ def main():
|
||||
print(f"{token_id:6d} -> '{token_str}'")
|
||||
|
||||
print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}")
|
||||
print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}") # type: ignore
|
||||
print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}")
|
||||
else:
|
||||
# Standard approach: use base model output only
|
||||
encoded = tokenizer(
|
||||
@@ -205,12 +205,12 @@ def main():
|
||||
print(f"Embedding dimension: {all_embeddings.shape[1]}")
|
||||
|
||||
if len(all_embeddings.shape) == 1:
|
||||
n_embd = all_embeddings.shape[0] # type: ignore
|
||||
n_embd = all_embeddings.shape[0]
|
||||
n_embd_count = 1
|
||||
all_embeddings = all_embeddings.reshape(1, -1)
|
||||
else:
|
||||
n_embd = all_embeddings.shape[1] # type: ignore
|
||||
n_embd_count = all_embeddings.shape[0] # type: ignore
|
||||
n_embd = all_embeddings.shape[1]
|
||||
n_embd_count = all_embeddings.shape[0]
|
||||
|
||||
print()
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from common import compare_tokens # type: ignore
|
||||
from common import compare_tokens # type: ignore[import-not-found]
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
|
||||
@@ -6,7 +6,7 @@ import re
|
||||
from copy import copy
|
||||
from enum import Enum
|
||||
from inspect import getdoc, isclass
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from docstring_parser import parse
|
||||
from pydantic import BaseModel, create_model
|
||||
@@ -1158,7 +1158,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
|
||||
|
||||
# Assert that the parameter has a type annotation
|
||||
if param.annotation == inspect.Parameter.empty:
|
||||
raise TypeError(f"Parameter '{param.name}' in function '{func.__name__}' lacks a type annotation")
|
||||
raise TypeError(f"""Parameter '{param.name}' in function '{getattr(func, "__name__", "")}' lacks a type annotation""")
|
||||
|
||||
# Find the parameter's description in the docstring
|
||||
param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
|
||||
@@ -1166,7 +1166,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
|
||||
# Assert that the parameter has a description
|
||||
if not param_doc or not param_doc.description:
|
||||
raise ValueError(
|
||||
f"Parameter '{param.name}' in function '{func.__name__}' lacks a description in the docstring")
|
||||
f"""Parameter '{param.name}' in function '{getattr(func, "__name__", "")}' lacks a description in the docstring""")
|
||||
|
||||
# Add parameter details to the schema
|
||||
param_docs.append((param.name, param_doc))
|
||||
@@ -1177,7 +1177,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
|
||||
dynamic_fields[param.name] = (
|
||||
param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
|
||||
# Creating the dynamic model
|
||||
dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)
|
||||
dynamic_model = create_model(f"{getattr(func, '__name__')}", **dynamic_fields)
|
||||
|
||||
for name, param_doc in param_docs:
|
||||
dynamic_model.model_fields[name].description = param_doc.description
|
||||
@@ -1285,7 +1285,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
|
||||
if items != {}:
|
||||
array = {"properties": items}
|
||||
array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items")
|
||||
fields[field_name] = (List[array_type], ...)
|
||||
fields[field_name] = (list[array_type], ...) # ty: ignore[invalid-type-form]
|
||||
else:
|
||||
fields[field_name] = (list, ...)
|
||||
elif field_type == "object":
|
||||
|
||||
@@ -1788,9 +1788,11 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0]; // src
|
||||
ggml_tensor * src1 = dst->src[1]; // index
|
||||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16
|
||||
|| dst->type == GGML_TYPE_BF16);
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
if (src0->type == dst->type) {
|
||||
@@ -1881,6 +1883,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
|
||||
ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t));
|
||||
@@ -1891,7 +1894,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor(
|
||||
src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
|
||||
aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
|
||||
dst->type);
|
||||
@@ -1965,7 +1968,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor *
|
||||
|
||||
// Only check env once.
|
||||
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
|
||||
if (weight_to_nz && is_matmul_weight(weight)) {
|
||||
if (weight_to_nz && weight->type != GGML_TYPE_BF16 && is_matmul_weight(weight)) {
|
||||
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
|
||||
} else {
|
||||
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
|
||||
@@ -2146,6 +2149,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
#ifndef ASCEND_310P
|
||||
case GGML_TYPE_BF16:
|
||||
#endif
|
||||
ggml_cann_mat_mul_fp(ctx, dst);
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
|
||||
@@ -1234,7 +1234,8 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
|
||||
if (!need_transform(tensor->type)) {
|
||||
ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
|
||||
if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
|
||||
if (weight_to_nz && tensor->type != GGML_TYPE_BF16
|
||||
&& is_matmul_weight((const ggml_tensor *) tensor)) {
|
||||
GGML_ASSERT(tensor->ne[2] == 1);
|
||||
GGML_ASSERT(tensor->ne[3] == 1);
|
||||
weight_format_to_nz(tensor, offset, ctx->device);
|
||||
@@ -1443,7 +1444,8 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t
|
||||
if (ne0 % MATRIX_ROW_PADDING != 0) {
|
||||
size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
|
||||
}
|
||||
} else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
|
||||
} else if (weight_to_nz && tensor->type != GGML_TYPE_BF16
|
||||
&& is_matmul_weight((const ggml_tensor *) tensor)) {
|
||||
// NZ format weight are not support quantized yet.
|
||||
// If ND tensor transform to NZ, size may changed.
|
||||
int64_t shape[] = { tensor->ne[1], tensor->ne[0] };
|
||||
@@ -2283,6 +2285,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
switch (op->src[0]->type) {
|
||||
#ifndef ASCEND_310P
|
||||
case GGML_TYPE_BF16:
|
||||
#endif
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
return true;
|
||||
@@ -2320,6 +2325,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
#ifndef ASCEND_310P
|
||||
case GGML_TYPE_BF16:
|
||||
#endif
|
||||
case GGML_TYPE_Q8_0:
|
||||
return true;
|
||||
default:
|
||||
@@ -2332,6 +2340,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
switch (op->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
#ifndef ASCEND_310P
|
||||
case GGML_TYPE_BF16:
|
||||
#endif
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -2341,20 +2352,30 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
case GGML_OP_CPY:
|
||||
{
|
||||
ggml_tensor * src = op->src[0];
|
||||
#ifdef ASCEND_310P
|
||||
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
|
||||
(src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) {
|
||||
// only support F32 and F16.
|
||||
// only support F32 and F16 on 310P.
|
||||
return false;
|
||||
}
|
||||
#else
|
||||
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_BF16) ||
|
||||
(src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16 && src->type != GGML_TYPE_BF16)) {
|
||||
// only support F32, F16 and BF16.
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_CONT:
|
||||
{
|
||||
// TODO: support GGML_TYPE_BF16
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
#ifndef ASCEND_310P
|
||||
case GGML_TYPE_BF16:
|
||||
#endif
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
@@ -572,9 +572,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
|
||||
set(KLEIDIAI_FETCH_ARGS
|
||||
URL ${KLEIDIAI_DOWNLOAD_URL}
|
||||
DOWNLOAD_EXTRACT_TIMESTAMP NEW
|
||||
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}
|
||||
)
|
||||
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
|
||||
list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW)
|
||||
endif()
|
||||
|
||||
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.28")
|
||||
FetchContent_Declare(KleidiAI_Download
|
||||
|
||||
@@ -3194,6 +3194,7 @@ class tinyBLAS_PPC {
|
||||
|
||||
private:
|
||||
|
||||
__attribute__((always_inline))
|
||||
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
||||
vec_t vec_C[4];
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
@@ -3204,6 +3205,7 @@ class tinyBLAS_PPC {
|
||||
}
|
||||
}
|
||||
|
||||
__attribute__((always_inline))
|
||||
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
||||
vec_t vec_C[4];
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
|
||||
@@ -116,12 +116,11 @@ if (CUDAToolkit_FOUND)
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
|
||||
else()
|
||||
file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
list(APPEND GGML_SOURCES_CUDA
|
||||
template-instances/fattn-vec-instance-f16-f16.cu
|
||||
template-instances/fattn-vec-instance-q4_0-q4_0.cu
|
||||
template-instances/fattn-vec-instance-q8_0-q8_0.cu
|
||||
template-instances/fattn-vec-instance-bf16-bf16.cu)
|
||||
endif()
|
||||
|
||||
ggml_add_backend_library(ggml-cuda
|
||||
|
||||
@@ -41,6 +41,16 @@ template<typename dst_t, typename src_t>
|
||||
return __bfloat162float(x);
|
||||
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
|
||||
return __float22half2_rn(x);
|
||||
} else if constexpr(std::is_same_v<src_t, nv_bfloat162> && std::is_same_v<dst_t, float2>) {
|
||||
#ifdef GGML_USE_HIP
|
||||
return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x)));
|
||||
#else
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __bfloat1622float2(x);
|
||||
#else
|
||||
return make_float2(__bfloat162float(x.x), __bfloat162float(x.y));
|
||||
#endif // __CUDA_ARCH__ >= 800
|
||||
#endif // GGML_USE_HIP
|
||||
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
|
||||
// bypass compile error on cuda 12.0.1
|
||||
#ifdef GGML_USE_HIP
|
||||
|
||||
@@ -74,6 +74,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
|
||||
return sum;
|
||||
}
|
||||
|
||||
template <int D, int nthreads>
|
||||
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
|
||||
|
||||
const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c;
|
||||
GGML_UNUSED(Q_q8);
|
||||
GGML_UNUSED(Q_ds_v);
|
||||
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
|
||||
__align__(16) nv_bfloat162 tmp[cpy_ne];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
// FIXME replace macros in vector FA kernel with templating and use FP32 for BF16
|
||||
ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]));
|
||||
#else
|
||||
ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template<int D, int nthreads>
|
||||
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
@@ -321,6 +352,19 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int ne>
|
||||
static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
||||
static_assert(std::is_same_v<T, float>, "BF16 V dequantization only supports float output");
|
||||
static_assert(ne % 2 == 0, "bad ne");
|
||||
__align__(16) nv_bfloat162 tmp[ne/2];
|
||||
ggml_cuda_memcpy_1<ne*sizeof(nv_bfloat16)>(tmp, (const nv_bfloat16 *) vx + i0);
|
||||
float2 * dst_f2 = (float2 *) dst;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne/2; ++l) {
|
||||
dst_f2[l] = ggml_cuda_cast<float2>(tmp[l]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int ne>
|
||||
static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
@@ -547,6 +591,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
|
||||
return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
|
||||
} else if constexpr (type_K == GGML_TYPE_Q8_0) {
|
||||
return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
|
||||
} else if constexpr (type_K == GGML_TYPE_BF16) {
|
||||
return vec_dot_fattn_vec_KQ_bf16<D, nthreads>;
|
||||
} else {
|
||||
static_assert(type_K == -1, "bad type");
|
||||
return nullptr;
|
||||
@@ -567,6 +613,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
|
||||
return dequantize_V_q5_1<T, ne>;
|
||||
} else if constexpr (type_V == GGML_TYPE_Q8_0) {
|
||||
return dequantize_V_q8_0<T, ne>;
|
||||
} else if constexpr (type_V == GGML_TYPE_BF16) {
|
||||
return dequantize_V_bf16<float, ne>;
|
||||
} else {
|
||||
static_assert(type_V == -1, "bad type");
|
||||
return nullptr;
|
||||
|
||||
@@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec(
|
||||
#endif // GGML_USE_HIP
|
||||
|
||||
constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
|
||||
constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
|
||||
constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
|
||||
constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q;
|
||||
constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q;
|
||||
|
||||
static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
|
||||
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
|
||||
|
||||
constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
|
||||
constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4;
|
||||
constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;
|
||||
|
||||
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
|
||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16;
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
|
||||
#else
|
||||
@@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec(
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
|
||||
half2 tmp[V_rows_per_thread/2];
|
||||
dequantize_V(V + k*nb21, tmp,
|
||||
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
|
||||
if constexpr (type_V == GGML_TYPE_BF16) {
|
||||
float2 tmp_f[V_rows_per_thread/2];
|
||||
dequantize_V(V + k*nb21, tmp_f,
|
||||
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
|
||||
tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]);
|
||||
}
|
||||
} else {
|
||||
dequantize_V(V + k*nb21, tmp,
|
||||
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
|
||||
#pragma unroll
|
||||
@@ -563,6 +573,7 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
|
||||
@@ -570,6 +581,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16)
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
|
||||
@@ -577,6 +589,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16)
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
|
||||
@@ -584,3 +597,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16)
|
||||
|
||||
@@ -224,6 +224,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||
@@ -231,6 +232,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
|
||||
@@ -238,6 +240,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
|
||||
@@ -245,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
|
||||
@@ -252,6 +256,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
|
||||
@@ -259,10 +264,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_BF16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
|
||||
#else
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
|
||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -355,6 +370,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_BF16:
|
||||
break;
|
||||
default:
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
|
||||
@@ -33,7 +33,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
|
||||
static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
|
||||
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
|
||||
@@ -173,11 +173,11 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d
|
||||
return 1;
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
|
||||
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) {
|
||||
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
|
||||
switch (ncols_dst) {
|
||||
case 1:
|
||||
return 1;
|
||||
return small_k ? nwarps : 1;
|
||||
case 2:
|
||||
case 3:
|
||||
case 4:
|
||||
@@ -193,7 +193,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
|
||||
return 1;
|
||||
}
|
||||
|
||||
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
|
||||
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false, bool small_k = false>
|
||||
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void mul_mat_vec_q(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
|
||||
@@ -208,7 +208,7 @@ static __global__ void mul_mat_vec_q(
|
||||
constexpr int vdr = get_vdr_mmvq(type);
|
||||
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
|
||||
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
|
||||
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
|
||||
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
|
||||
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||
@@ -414,14 +414,16 @@ static __global__ void mul_mat_vec_q(
|
||||
template<ggml_type type>
|
||||
static std::pair<dim3, dim3> calc_launch_params(
|
||||
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
|
||||
const int warp_size, const mmvq_parameter_table_id table_id) {
|
||||
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
|
||||
const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) {
|
||||
const int nwarps = calc_nwarps(type, ncols_dst, table_id);
|
||||
const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
|
||||
const int64_t nblocks = (nrows_x + rpb - 1) / rpb;
|
||||
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
|
||||
const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1);
|
||||
const dim3 block_dims(warp_size, nwarps, 1);
|
||||
return {block_nums, block_dims};
|
||||
}
|
||||
|
||||
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>
|
||||
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false, bool small_k = false>
|
||||
static void mul_mat_vec_q_switch_fusion(
|
||||
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
|
||||
@@ -434,7 +436,7 @@ static void mul_mat_vec_q_switch_fusion(
|
||||
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||
if constexpr (c_ncols_dst == 1) {
|
||||
if (has_fusion) {
|
||||
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
|
||||
@@ -444,7 +446,7 @@ static void mul_mat_vec_q_switch_fusion(
|
||||
|
||||
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
|
||||
|
||||
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
|
||||
@@ -488,11 +490,33 @@ static void mul_mat_vec_q_switch_ncols_dst(
|
||||
switch (ncols_dst) {
|
||||
case 1: {
|
||||
constexpr int c_ncols_dst = 1;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, ids_stride, stream);
|
||||
|
||||
// When K is small, increase rows_per_block to match nwarps so each warp has more work to do
|
||||
// Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
||||
constexpr int vdr = get_vdr_mmvq(type);
|
||||
const int blocks_per_row_x = ncols_x / qk;
|
||||
const int blocks_per_iter_1warp = vdr * warp_size / qi;
|
||||
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
|
||||
const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
|
||||
if (use_small_k) {
|
||||
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
|
||||
warp_size, table_id, true);
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, false, true>(
|
||||
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, ids_stride, stream);
|
||||
} else {
|
||||
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
|
||||
warp_size, table_id);
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
|
||||
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, ids_stride, stream);
|
||||
}
|
||||
} break;
|
||||
case 2: {
|
||||
constexpr int c_ncols_dst = 2;
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
|
||||
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
|
||||
|
||||
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
|
||||
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
|
||||
|
||||
SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ static int opt_verbose = 0;
|
||||
static int opt_profile = 0;
|
||||
static int opt_hostbuf = 1; // hostbuf ON by default
|
||||
static int opt_experimental = 0;
|
||||
static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only
|
||||
|
||||
// Enable all stages by default
|
||||
static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE;
|
||||
@@ -1693,7 +1694,7 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
|
||||
// Start the DSP-side service. We need to pass the queue ID to the
|
||||
// DSP in a FastRPC call; the DSP side will import the queue and start
|
||||
// listening for packets in a callback.
|
||||
err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx);
|
||||
err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx);
|
||||
if (err != 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err);
|
||||
throw std::runtime_error("ggml-hex: iface start failed (see log for details)");
|
||||
@@ -3372,6 +3373,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
|
||||
const char * str_etm = getenv("GGML_HEXAGON_ETM");
|
||||
const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
|
||||
const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
|
||||
const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
|
||||
const char * str_arch = getenv("GGML_HEXAGON_ARCH");
|
||||
|
||||
@@ -3381,8 +3383,9 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask;
|
||||
opt_opsync = str_opsync ? atoi(str_opsync) : 0;
|
||||
opt_profile = str_profile ? atoi(str_profile) : 0;
|
||||
opt_etm = str_etm ? atoi(str_etm) : 0;
|
||||
opt_etm = str_etm ? atoi(str_etm) : 0;
|
||||
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
|
||||
opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx;
|
||||
opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
|
||||
|
||||
if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
|
||||
|
||||
@@ -40,6 +40,24 @@ target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
|
||||
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
|
||||
|
||||
# HMX acceleration: available on v73+ architectures
|
||||
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
|
||||
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
|
||||
|
||||
if (_hmx_idx GREATER_EQUAL 0)
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-matmul-ops.c
|
||||
)
|
||||
|
||||
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
|
||||
set_source_files_properties(
|
||||
hmx-matmul-ops.c
|
||||
PROPERTIES COMPILE_OPTIONS "-mhmx"
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1)
|
||||
endif()
|
||||
|
||||
build_idl(htp_iface.idl ${HTP_LIB})
|
||||
|
||||
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
@@ -175,6 +175,86 @@ static inline uint32_t dma_queue_capacity(dma_queue * q) {
|
||||
return q->capacity;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Overflow-safe DMA push: all UDMA type1 descriptor fields (roiwidth,
|
||||
// roiheight, srcstride, dststride) are 16-bit, max 65535. This helper
|
||||
// transparently handles values that exceed the 16-bit limit and submits
|
||||
// chained DMA transtions.
|
||||
//
|
||||
// Case 1 (fast path): all params fit in 16 bits -> direct dma_queue_push.
|
||||
// Case 2 (contiguous block): width == srcstride == dststride. Reshape the
|
||||
// flat transfer into a 2D descriptor with sub_width <= 65535. Produces a
|
||||
// single descriptor, preserving async DMA behavior.
|
||||
// Case 3 (stride overflow): srcstride or dststride > 65535. Issue rows
|
||||
// one at a time. The first N-1 rows are pushed+popped synchronously;
|
||||
// the last row is left async so the caller can pop it.
|
||||
// ---------------------------------------------------------------------------
|
||||
#define UDMA_MAX_FIELD_VAL 65535u
|
||||
|
||||
static inline bool dma_queue_push_chained(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t width, size_t nrows) {
|
||||
// Fast path: everything fits in 16 bits.
|
||||
if (__builtin_expect(
|
||||
width <= UDMA_MAX_FIELD_VAL &&
|
||||
nrows <= UDMA_MAX_FIELD_VAL &&
|
||||
src_stride <= UDMA_MAX_FIELD_VAL &&
|
||||
dst_stride <= UDMA_MAX_FIELD_VAL, 1)) {
|
||||
return dma_queue_push(q, dptr, dst_stride, src_stride, width, nrows);
|
||||
}
|
||||
|
||||
// Case 2: contiguous block (width == src_stride == dst_stride).
|
||||
// Reshape total bytes into sub_width * sub_nrows where sub_width <= 65535.
|
||||
if (width == src_stride && width == dst_stride) {
|
||||
size_t total = width * nrows;
|
||||
|
||||
// Pick the largest 128-byte-aligned sub_width that divides total evenly.
|
||||
size_t sub_width = UDMA_MAX_FIELD_VAL & ~(size_t)127; // 65408
|
||||
while (sub_width > 0 && total % sub_width != 0) {
|
||||
sub_width -= 128;
|
||||
}
|
||||
if (sub_width == 0) {
|
||||
// Fallback: use original width (must fit) with adjusted nrows.
|
||||
// This shouldn't happen for 128-aligned DMA sizes.
|
||||
sub_width = width;
|
||||
}
|
||||
size_t sub_nrows = total / sub_width;
|
||||
|
||||
// Handle sub_nrows > 65535 by issuing chunked descriptors.
|
||||
const uint8_t *src = (const uint8_t *)dptr.src;
|
||||
uint8_t *dst = (uint8_t *)dptr.dst;
|
||||
size_t rows_done = 0;
|
||||
while (rows_done < sub_nrows) {
|
||||
size_t chunk = sub_nrows - rows_done;
|
||||
if (chunk > UDMA_MAX_FIELD_VAL) chunk = UDMA_MAX_FIELD_VAL;
|
||||
|
||||
dma_ptr p = dma_make_ptr(dst + rows_done * sub_width, src + rows_done * sub_width);
|
||||
if (!dma_queue_push(q, p, sub_width, sub_width, sub_width, chunk))
|
||||
return false;
|
||||
|
||||
rows_done += chunk;
|
||||
// Complete all chunks without waiting except the last one, so the
|
||||
// caller's single dma_queue_pop drains the final descriptor.
|
||||
if (rows_done < sub_nrows)
|
||||
dma_queue_pop_nowait(q);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Case 3: stride overflow — fall back to row-by-row.
|
||||
{
|
||||
const uint8_t *src = (const uint8_t *)dptr.src;
|
||||
uint8_t *dst = (uint8_t *)dptr.dst;
|
||||
for (size_t r = 0; r < nrows; ++r) {
|
||||
dma_ptr p = dma_make_ptr(dst + r * dst_stride,
|
||||
src + r * src_stride);
|
||||
if (!dma_queue_push(q, p, 0, 0, width, 1))
|
||||
return false;
|
||||
if (r + 1 < nrows)
|
||||
dma_queue_pop_nowait(q);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
@@ -29,10 +29,22 @@ static inline uint64_t hex_get_pktcnt() {
|
||||
return pktcnt;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_aligned(void * addr, uint32_t align) {
|
||||
static inline size_t hmx_ceil_div(size_t num, size_t den) {
|
||||
return (num + den - 1) / den;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
|
||||
return ((size_t) addr & (align - 1)) == 0;
|
||||
}
|
||||
|
||||
static inline size_t hex_align_up(size_t v, size_t align) {
|
||||
return hmx_ceil_div(v, align) * align;
|
||||
}
|
||||
|
||||
static inline size_t hex_align_down(size_t v, size_t align) {
|
||||
return (v / align) * align;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
|
||||
uint32_t left_off = (size_t) addr & (chunk_size - 1);
|
||||
uint32_t right_off = left_off + n;
|
||||
@@ -43,6 +55,14 @@ static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
|
||||
return m * ((n + m - 1) / m);
|
||||
}
|
||||
|
||||
static inline size_t hex_smin(size_t a, size_t b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
static inline size_t hex_smax(size_t a, size_t b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
|
||||
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
|
||||
Q6_l2fetch_AP((void *) p, control);
|
||||
|
||||
1528
ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c
Normal file
1528
ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c
Normal file
File diff suppressed because it is too large
Load Diff
72
ggml/src/ggml-hexagon/htp/hmx-ops.h
Normal file
72
ggml/src/ggml-hexagon/htp/hmx-ops.h
Normal file
@@ -0,0 +1,72 @@
|
||||
// HMX operation entry-point declarations.
|
||||
// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib)
|
||||
|
||||
#ifndef HMX_OPS_H
|
||||
#define HMX_OPS_H
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#ifndef restrict
|
||||
# define restrict __restrict
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct htp_context; // forward declaration
|
||||
|
||||
typedef struct {
|
||||
float *dst;
|
||||
const float *activation;
|
||||
const __fp16 *permuted_weight;
|
||||
int m;
|
||||
int k;
|
||||
int n;
|
||||
int act_stride;
|
||||
int weight_stride;
|
||||
int dst_stride;
|
||||
int ne02;
|
||||
int ne03;
|
||||
int ne12;
|
||||
int ne13;
|
||||
size_t src0_nb2;
|
||||
size_t src0_nb3;
|
||||
size_t src1_nb2;
|
||||
size_t src1_nb3;
|
||||
size_t dst_nb2;
|
||||
size_t dst_nb3;
|
||||
} hmx_matmul_w16a32_batched_params_t;
|
||||
|
||||
// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
|
||||
// act_stride: activation row stride in elements (= k for contiguous, or
|
||||
// nb[1]/sizeof(float) for permuted tensors like attention Q).
|
||||
// weight_stride: weight row stride in elements (= k for compact weights, or
|
||||
// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
|
||||
int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const __fp16 *permuted_weight,
|
||||
int m, int k, int n,
|
||||
int act_stride,
|
||||
int weight_stride);
|
||||
|
||||
// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32.
|
||||
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
|
||||
int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx,
|
||||
const hmx_matmul_w16a32_batched_params_t *params);
|
||||
|
||||
// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL)
|
||||
int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const uint8_t *permuted_weight,
|
||||
int m, int k, int n,
|
||||
int weight_type);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // HMX_OPS_H
|
||||
34
ggml/src/ggml-hexagon/htp/hmx-profile.h
Normal file
34
ggml/src/ggml-hexagon/htp/hmx-profile.h
Normal file
@@ -0,0 +1,34 @@
|
||||
// Conditional fine-grained profiling macros for HMX operations.
|
||||
//
|
||||
// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this
|
||||
// header) to instrument sub-operation latencies with HAP qtimer. When the
|
||||
// macro is not defined the TIMER_* helpers expand to nothing so there is zero
|
||||
// overhead.
|
||||
//
|
||||
// Usage:
|
||||
// TIMER_DEFINE(my_phase); // declare accumulator variable
|
||||
// TIMER_START(my_phase); // snapshot start time
|
||||
// ... work ...
|
||||
// TIMER_STOP(my_phase); // accumulate elapsed ticks
|
||||
// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase));
|
||||
|
||||
#ifndef HMX_PROFILE_H
|
||||
#define HMX_PROFILE_H
|
||||
|
||||
#include <HAP_perf.h>
|
||||
|
||||
// #define ENABLE_PROFILE_TIMERS
|
||||
|
||||
#if defined(ENABLE_PROFILE_TIMERS)
|
||||
# define TIMER_DEFINE(name) int64_t name##_ticks = 0
|
||||
# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count()
|
||||
# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0
|
||||
# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks)
|
||||
#else
|
||||
# define TIMER_DEFINE(name)
|
||||
# define TIMER_START(name)
|
||||
# define TIMER_STOP(name)
|
||||
# define TIMER_US(name) 0LL
|
||||
#endif
|
||||
|
||||
#endif // HMX_PROFILE_H
|
||||
88
ggml/src/ggml-hexagon/htp/hmx-utils.h
Normal file
88
ggml/src/ggml-hexagon/htp/hmx-utils.h
Normal file
@@ -0,0 +1,88 @@
|
||||
// HMX tile-level inline helpers (FP16 32x32 tile operations).
|
||||
// Ported from htp-ops-lib/include/dsp/hmx_utils.h. (https://github.com/haozixu/htp-ops-lib)
|
||||
|
||||
#ifndef HMX_UTILS_H
|
||||
#define HMX_UTILS_H
|
||||
|
||||
#include <hexagon_types.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#define HMX_FP16_TILE_N_ROWS 32
|
||||
#define HMX_FP16_TILE_N_COLS 32
|
||||
#define HMX_FP16_TILE_N_ELMS 1024
|
||||
#define HMX_FP16_TILE_SIZE 2048
|
||||
|
||||
#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline))
|
||||
|
||||
static HMX_INLINE_ALWAYS void hmx_set_output_scales(const void *scales) {
|
||||
asm volatile("bias = mxmem2(%0)" :: "r"(scales));
|
||||
}
|
||||
|
||||
// Initialise aligned 256-byte area with scale vector + zero padding.
|
||||
static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
|
||||
HVX_Vector *pv = (HVX_Vector *)out_scales;
|
||||
*pv++ = v_scale;
|
||||
*pv = Q6_V_vzero();
|
||||
}
|
||||
|
||||
// Load multiple contiguous tiles with :deep streaming.
|
||||
// Rt = total region size - 1; the hardware streams through [Rs, Rs + Rt].
|
||||
// IMPORTANT: the tile region [Rs, Rs + Rt] must NOT cross a VTCM 4 MB bank
|
||||
// boundary, otherwise the mxmem instruction will raise a precise bus error.
|
||||
// Callers must ensure their VTCM layout satisfies this constraint.
|
||||
static HMX_INLINE_ALWAYS void hmx_load_tiles_fp16(const __fp16 *row_tiles,
|
||||
const __fp16 *col_tiles,
|
||||
size_t n_tiles) {
|
||||
size_t limit = n_tiles * HMX_FP16_TILE_SIZE - 1;
|
||||
asm volatile(
|
||||
"{ activation.hf = mxmem(%0, %1):deep\n"
|
||||
"weight.hf = mxmem(%2, %3) }\n"
|
||||
:: "r"(row_tiles), "r"(limit), "r"(col_tiles), "r"(limit)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Load a single activation+weight tile pair (no :deep streaming).
|
||||
// Rt defines the accessible region [Rs, Rs+Rt]. Following the reference formula
|
||||
// (limit = n_tiles * HMX_FP16_TILE_SIZE - 1), for a single tile Rt = 2047.
|
||||
// The original code used Rt=0x7FFF (32 KB region); when dynamic VTCM allocation
|
||||
// places a tile near a 4 MB bank boundary, the oversized region crosses it and
|
||||
// triggers a precise bus error (0x2601). Rt=2047 confines accesses to exactly
|
||||
// one 2048-byte tile while covering all 16 HVX vectors (offsets 0..2047).
|
||||
static HMX_INLINE_ALWAYS void hmx_load_tile_pair_fp16(const __fp16 *act_tile,
|
||||
const __fp16 *wt_tile) {
|
||||
asm volatile(
|
||||
"{ activation.hf = mxmem(%0, %1)\n"
|
||||
"weight.hf = mxmem(%2, %3) }\n"
|
||||
:: "r"(act_tile), "r"(2047),
|
||||
"r"(wt_tile), "r"(2047)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static HMX_INLINE_ALWAYS void hmx_consume_accumulator_fp16(__fp16 *out) {
|
||||
// Use the combined convert-and-store instruction (matches the reference
|
||||
// Q6_mxmem_AR_after_hf intrinsic). The previous two-instruction sequence
|
||||
// "cvt.hf = acc(2); mxmem = cvt" used an undocumented Rs=2 parameter.
|
||||
asm volatile(
|
||||
"mxmem(%0, %1):after.hf = acc\n"
|
||||
:: "r"(out), "r"(0)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Compute inner product of two vectors of tiles and store result.
|
||||
static HMX_INLINE_ALWAYS void hmx_dot_fp16(__fp16 *out,
|
||||
const __fp16 *row_tiles,
|
||||
const __fp16 *col_tiles,
|
||||
size_t n_tiles) {
|
||||
hmx_load_tiles_fp16(row_tiles, col_tiles, n_tiles);
|
||||
hmx_consume_accumulator_fp16(out);
|
||||
}
|
||||
|
||||
// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) ---
|
||||
|
||||
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
|
||||
uint8_t *p = *vtcm_ptr;
|
||||
*vtcm_ptr += size;
|
||||
return p;
|
||||
}
|
||||
|
||||
#endif // HMX_UTILS_H
|
||||
@@ -30,6 +30,12 @@ struct htp_context {
|
||||
atomic_bool vtcm_needs_release;
|
||||
|
||||
uint32_t opmask;
|
||||
|
||||
// HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX)
|
||||
#ifdef HTP_HAS_HMX
|
||||
int hmx_enabled; // Runtime flag: HMX initialisation succeeded
|
||||
size_t vtcm_scratch_size; // Usable dynamic scratch (vtcm_size minus tail reservation)
|
||||
#endif
|
||||
};
|
||||
|
||||
#endif /* HTP_CTX_H */
|
||||
|
||||
@@ -32,13 +32,14 @@ enum htp_status {
|
||||
// Duplicated here because we can't include full ggml.h in the htp build.
|
||||
// We have some static_asserts in the cpp code to ensure things are in sync.
|
||||
enum htp_data_type {
|
||||
HTP_TYPE_F32 = 0,
|
||||
HTP_TYPE_F16 = 1,
|
||||
HTP_TYPE_Q4_0 = 2,
|
||||
HTP_TYPE_Q8_0 = 8,
|
||||
HTP_TYPE_I32 = 26,
|
||||
HTP_TYPE_I64 = 27,
|
||||
HTP_TYPE_MXFP4 = 39,
|
||||
HTP_TYPE_F32 = 0,
|
||||
HTP_TYPE_F16 = 1,
|
||||
HTP_TYPE_Q4_0 = 2,
|
||||
HTP_TYPE_Q8_0 = 8,
|
||||
HTP_TYPE_IQ4_NL = 20,
|
||||
HTP_TYPE_I32 = 26,
|
||||
HTP_TYPE_I64 = 27,
|
||||
HTP_TYPE_MXFP4 = 39,
|
||||
HTP_TYPE_COUNT
|
||||
};
|
||||
|
||||
@@ -87,6 +88,8 @@ static inline size_t htp_t_block_size(uint32_t t) {
|
||||
return QK4_0;
|
||||
case HTP_TYPE_Q8_0:
|
||||
return QK8_0;
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return QK4_NL;
|
||||
case HTP_TYPE_MXFP4:
|
||||
return QK_MXFP4;
|
||||
default:
|
||||
@@ -105,6 +108,8 @@ static inline size_t htp_type_nbytes(uint32_t t) {
|
||||
return sizeof(block_q4_0);
|
||||
case HTP_TYPE_Q8_0:
|
||||
return sizeof(block_q8_0);
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return sizeof(block_iq4_nl);
|
||||
case HTP_TYPE_MXFP4:
|
||||
return sizeof(block_mxfp4);
|
||||
default:
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "remote.idl"
|
||||
|
||||
interface htp_iface : remote_handle64 {
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx);
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx);
|
||||
AEEResult stop();
|
||||
AEEResult enable_etm();
|
||||
AEEResult disable_etm();
|
||||
|
||||
@@ -9,6 +9,9 @@
|
||||
#include "hex-utils.h"
|
||||
#include "hvx-types.h"
|
||||
|
||||
#define hvx_vmem(A) *((HVX_Vector *)(A))
|
||||
#define hvx_vmemu(A) *((HVX_UVector *)(A))
|
||||
|
||||
static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) {
|
||||
// Rotate as needed.
|
||||
v = Q6_V_vlalign_VVR(v, v, (size_t) dst);
|
||||
@@ -112,11 +115,15 @@ static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {
|
||||
return Q6_Q_and_QQ(p_exp, p_frac);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) {
|
||||
const HVX_Vector zero = Q6_V_vzero();
|
||||
HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);
|
||||
HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);
|
||||
HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)));
|
||||
return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0));
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
|
||||
HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
// replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
|
||||
@@ -128,6 +135,30 @@ static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
|
||||
return v;
|
||||
}
|
||||
|
||||
#if __HVX_ARCH__ >= 79
|
||||
static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) {
|
||||
const HVX_Vector one = hvx_vec_splat_f16(1.0);
|
||||
HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(v, one);
|
||||
return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p));
|
||||
}
|
||||
static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
|
||||
const HVX_Vector one = hvx_vec_splat_f16(1.0);
|
||||
HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one);
|
||||
return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p));
|
||||
}
|
||||
#else
|
||||
static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) {
|
||||
const HVX_Vector one = hvx_vec_splat_f16(1.0);
|
||||
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(v, one);
|
||||
return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)));
|
||||
}
|
||||
static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
|
||||
const HVX_Vector one = hvx_vec_splat_f16(1.0);
|
||||
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one);
|
||||
return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)));
|
||||
}
|
||||
#endif
|
||||
|
||||
/* Q6_Vsf_equals_Vw is only available on v73+.*/
|
||||
#if __HVX_ARCH__ < 73
|
||||
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)
|
||||
|
||||
@@ -25,6 +25,10 @@
|
||||
#include "htp-ops.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
#include "hmx-ops.h"
|
||||
#endif // HTP_HAS_HMX
|
||||
|
||||
AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
|
||||
struct htp_context * ctx;
|
||||
int err = 0;
|
||||
@@ -163,6 +167,9 @@ static int vtcm_acquire(struct htp_context * ctx) {
|
||||
}
|
||||
|
||||
ctx->vtcm_inuse = true;
|
||||
|
||||
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -246,7 +253,7 @@ static void vtcm_free(struct htp_context * ctx) {
|
||||
static void htp_packet_callback(dspqueue_t queue, int error, void * context);
|
||||
static void htp_error_callback(dspqueue_t queue, int error, void * context);
|
||||
|
||||
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) {
|
||||
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx) {
|
||||
struct htp_context * ctx = (struct htp_context *) handle;
|
||||
|
||||
if (!ctx) {
|
||||
@@ -280,6 +287,21 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
||||
return AEE_ENOMEMORY;
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
if (use_hmx) {
|
||||
ctx->vtcm_scratch_size = ctx->vtcm_size;
|
||||
ctx->hmx_enabled = 1;
|
||||
|
||||
FARF(HIGH, "HMX enabled: vtcm-scratch %zu", ctx->vtcm_scratch_size);
|
||||
} else {
|
||||
// HMX disabled: skip HMX initialisation so the
|
||||
// dispatch loop falls through to the HVX compute paths.
|
||||
ctx->hmx_enabled = 0;
|
||||
ctx->vtcm_scratch_size = ctx->vtcm_size;
|
||||
FARF(HIGH, "HMX disabled (use_hmx=0): vtcm-scratch %zu", ctx->vtcm_scratch_size);
|
||||
}
|
||||
#endif
|
||||
|
||||
qurt_sysenv_max_hthreads_t hw_threads;
|
||||
qurt_sysenv_get_max_hw_threads(&hw_threads);
|
||||
uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
|
||||
@@ -340,6 +362,12 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
|
||||
for (int i = 0; i < ctx->n_threads; i++) {
|
||||
dma_queue_delete(ctx->dma[i]);
|
||||
}
|
||||
#ifdef HTP_HAS_HMX
|
||||
if (ctx->hmx_enabled) {
|
||||
ctx->hmx_enabled = 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
vtcm_free(ctx);
|
||||
|
||||
@@ -375,8 +403,9 @@ static int send_htp_rsp(struct htp_context * c,
|
||||
struct dspqueue_buffer * bufs,
|
||||
size_t n_bufs,
|
||||
struct profile_data * prof) {
|
||||
// Prep response struct
|
||||
// Prep response struct (zero-init to clear cmp/unused union)
|
||||
struct htp_general_rsp rsp;
|
||||
memset(&rsp, 0, sizeof(rsp));
|
||||
rsp.op = op;
|
||||
rsp.status = status;
|
||||
rsp.prof_usecs = prof->usecs;
|
||||
@@ -1037,6 +1066,210 @@ static void proc_flash_attn_ext_req(struct htp_context * ctx,
|
||||
send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof);
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
// ---------------------------------------------------------------------------
|
||||
// HMX operation wrappers — self-contained, bypass htp_ops_context / htp_spad.
|
||||
// VTCM, DMA and thread dispatch are managed inside the HMX kernels.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static void proc_hmx_matmul_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
size_t n_bufs) {
|
||||
// HMX weight tile requires N to be 32-aligned.
|
||||
if (req->src0.ne[1] % 32 != 0) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
|
||||
const bool is_batched = (req->src0.ne[2] * req->src0.ne[3] > 1 ||
|
||||
req->src1.ne[2] * req->src1.ne[3] > 1);
|
||||
|
||||
// Quantised HMX kernels only handle flat 2D matmul (host already rejects
|
||||
// batched quantised, but guard here too). F16 batched matmul is handled
|
||||
// by the dedicated wrapper in hmx-matmul-ops.c.
|
||||
if (is_batched &&
|
||||
req->src0.type != HTP_TYPE_F16) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
|
||||
// HMX assumes contiguous row-major layout. Fall back for permuted
|
||||
// tensors where strides are non-monotonic (e.g. transposed KV cache).
|
||||
if (req->src0.nb[0] > req->src0.nb[1] ||
|
||||
req->src1.nb[0] > req->src1.nb[1]) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
|
||||
// M alignment: when M > 32 but not 32-aligned, we split into
|
||||
// HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
|
||||
// When M <= 32 and not 32-aligned, fall back entirely to HVX.
|
||||
const int m_total = (int) req->src1.ne[1];
|
||||
const int m_tail = m_total % 32;
|
||||
const int m_hmx = m_total - m_tail;
|
||||
|
||||
if (m_hmx == 0) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
|
||||
// HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights.
|
||||
// Other types (e.g. MXFP4) fall back to HVX.
|
||||
{
|
||||
uint32_t wtype = req->src0.type;
|
||||
if (wtype != HTP_TYPE_F16 &&
|
||||
wtype != HTP_TYPE_Q4_0 &&
|
||||
wtype != HTP_TYPE_Q8_0 &&
|
||||
wtype != HTP_TYPE_IQ4_NL) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
// Quantised HMX path requires K aligned to 256 (x4x2 super-block).
|
||||
// F16 HMX path requires K aligned to 32 (tile width).
|
||||
if (wtype != HTP_TYPE_F16 && req->src0.ne[0] % 256 != 0) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
if (wtype == HTP_TYPE_F16 && req->src0.ne[0] % 32 != 0) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
(void) n_bufs;
|
||||
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
rsp_bufs[0].fd = bufs[2].fd;
|
||||
rsp_bufs[0].ptr = bufs[2].ptr;
|
||||
rsp_bufs[0].size = bufs[2].size;
|
||||
rsp_bufs[0].offset = bufs[2].offset;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);
|
||||
|
||||
// src0 = weights, src1 = activation, dst = output
|
||||
void * wgt = (void *) bufs[0].ptr;
|
||||
float * act = (float *) bufs[1].ptr;
|
||||
float * dst = (float *) bufs[2].ptr;
|
||||
|
||||
int k = (int) req->src0.ne[0]; // inner dimension
|
||||
int n = (int) req->src0.ne[1]; // weight columns
|
||||
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
|
||||
// --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
int ret = -1;
|
||||
|
||||
const int ne02 = (int) req->src0.ne[2];
|
||||
const int ne03 = (int) req->src0.ne[3];
|
||||
const int ne12 = (int) req->src1.ne[2];
|
||||
const int ne13 = (int) req->src1.ne[3];
|
||||
// Row strides in elements. For compact tensors these equal k; for
|
||||
// permuted attention views they can be larger, so pass the real stride.
|
||||
const int act_stride = (int)(req->src1.nb[1] / sizeof(float));
|
||||
const int weight_stride = (int)(req->src0.nb[1] / sizeof(__fp16));
|
||||
|
||||
switch (req->src0.type) {
|
||||
case HTP_TYPE_F16:
|
||||
if (is_batched) {
|
||||
hmx_matmul_w16a32_batched_params_t batch_params = {
|
||||
.dst = dst,
|
||||
.activation = act,
|
||||
.permuted_weight = (const __fp16 *) wgt,
|
||||
.m = m_hmx,
|
||||
.k = k,
|
||||
.n = n,
|
||||
.act_stride = act_stride,
|
||||
.weight_stride = weight_stride,
|
||||
.dst_stride = (int)(req->dst.nb[1] / sizeof(float)),
|
||||
.ne02 = ne02,
|
||||
.ne03 = ne03,
|
||||
.ne12 = ne12,
|
||||
.ne13 = ne13,
|
||||
.src0_nb2 = req->src0.nb[2],
|
||||
.src0_nb3 = req->src0.nb[3],
|
||||
.src1_nb2 = req->src1.nb[2],
|
||||
.src1_nb3 = req->src1.nb[3],
|
||||
.dst_nb2 = req->dst.nb[2],
|
||||
.dst_nb3 = req->dst.nb[3],
|
||||
};
|
||||
ret = hmx_mat_mul_permuted_w16a32_batched(ctx, &batch_params);
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_w16a32(ctx, dst, act,
|
||||
(const __fp16 *) wgt,
|
||||
m_hmx, k, n,
|
||||
act_stride,
|
||||
weight_stride);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
ret = hmx_mat_mul_permuted_qk_0_d16a32(ctx, dst, act,
|
||||
(const uint8_t *) wgt,
|
||||
m_hmx, k, n, (int) req->src0.type);
|
||||
break;
|
||||
}
|
||||
|
||||
if (ret == 0) {
|
||||
rsp_status = HTP_STATUS_OK;
|
||||
} else {
|
||||
FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret);
|
||||
vtcm_release(ctx);
|
||||
req->flags &= ~HTP_OPFLAGS_SKIP_QUANTIZE;
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
// --- Phase 2: HVX on the remaining m_tail rows ---
|
||||
if (m_tail > 0 && rsp_status == HTP_STATUS_OK) {
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0; // weights: unchanged
|
||||
octx.src1 = req->src1;
|
||||
octx.src1.ne[1] = m_tail; // only tail rows
|
||||
octx.dst = req->dst;
|
||||
octx.dst.ne[1] = m_tail; // only tail rows
|
||||
// Always re-quantize tail src1: HMX Phase 1 overwrites VTCM,
|
||||
// so any previously cached quantized data (SKIP_QUANTIZE pipeline)
|
||||
// is invalid.
|
||||
octx.flags = req->flags & ~HTP_OPFLAGS_SKIP_QUANTIZE;
|
||||
octx.op = req->op;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
// Offset activation and dst pointers past the HMX-processed rows.
|
||||
// Use nb[1] (row stride in bytes) to compute the byte offset.
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.src1.data = (uint32_t)((uint8_t *) bufs[1].ptr + (size_t) m_hmx * req->src1.nb[1]);
|
||||
octx.dst.data = (uint32_t)((uint8_t *) bufs[2].ptr + (size_t) m_hmx * req->dst.nb[1]);
|
||||
|
||||
FARF(HIGH, "proc_hmx_matmul: HVX tail m_tail=%d act=%p dst=%p",
|
||||
m_tail, (void *)(uintptr_t) octx.src1.data, (void *)(uintptr_t) octx.dst.data);
|
||||
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
uint32_t hvx_ret = op_matmul(&octx);
|
||||
vtcm_release(ctx);
|
||||
if (hvx_ret != HTP_STATUS_OK) {
|
||||
FARF(ERROR, "HVX tail matmul failed (ret=%u)", hvx_ret);
|
||||
rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
}
|
||||
} else {
|
||||
rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
}
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
#endif // HTP_HAS_HMX
|
||||
|
||||
static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
struct htp_context * ctx = (struct htp_context *) context;
|
||||
|
||||
@@ -1089,7 +1322,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
FARF(ERROR, "Bad matmul-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_matmul_req(ctx, &req, bufs, n_bufs);
|
||||
#ifdef HTP_HAS_HMX
|
||||
if (ctx->hmx_enabled) {
|
||||
proc_hmx_matmul_req(ctx, &req, bufs, n_bufs);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
proc_matmul_req(ctx, &req, bufs, n_bufs);
|
||||
}
|
||||
break;
|
||||
|
||||
case HTP_OP_MUL_MAT_ID:
|
||||
|
||||
@@ -53,9 +53,6 @@ endif()
|
||||
|
||||
message(STATUS "HIP and hipBLAS found")
|
||||
|
||||
# Workaround old compilers
|
||||
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024")
|
||||
|
||||
file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
|
||||
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
|
||||
|
||||
@@ -74,12 +71,11 @@ if (GGML_CUDA_FA_ALL_QUANTS)
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
|
||||
else()
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
list(APPEND GGML_SOURCES_ROCM
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
|
||||
endif()
|
||||
|
||||
ggml_add_backend_library(ggml-hip
|
||||
@@ -132,6 +128,11 @@ endif()
|
||||
|
||||
if (CXX_IS_HIPCC)
|
||||
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
|
||||
if (WIN32 AND CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||
# CMake on Windows doesn't support the HIP language yet.
|
||||
# Therefore we workaround debug build's failure on HIP backend this way.
|
||||
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES COMPILE_FLAGS "-O2 -g")
|
||||
endif()
|
||||
target_link_libraries(ggml-hip PRIVATE hip::device)
|
||||
else()
|
||||
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP)
|
||||
|
||||
@@ -48,12 +48,11 @@ if (MUSAToolkit_FOUND)
|
||||
list(APPEND GGML_SOURCES_MUSA ${SRCS})
|
||||
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
|
||||
else()
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
|
||||
list(APPEND GGML_SOURCES_MUSA ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
|
||||
list(APPEND GGML_SOURCES_MUSA ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
|
||||
list(APPEND GGML_SOURCES_MUSA ${SRCS})
|
||||
list(APPEND GGML_SOURCES_MUSA
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
|
||||
endif()
|
||||
|
||||
set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)
|
||||
|
||||
@@ -1162,12 +1162,18 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Fix: Prevent division by zero if blck_size is 0 (e.g., deprecated types)
|
||||
if (ggml_blck_size((enum ggml_type)tensor->type) == 0) {
|
||||
GGML_LOG_ERROR("[%s] invalid tensor type received (blck_size is 0): %u\n", __func__, tensor->type);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
|
||||
tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
||||
|
||||
// ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
|
||||
if (result == nullptr) {
|
||||
GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type);
|
||||
GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\n", __func__, tensor->type);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
@@ -4667,22 +4667,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (a->ne[3] != b->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
ggml_type a_type = a->type;
|
||||
if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
|
||||
a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
|
||||
a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S ||
|
||||
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M
|
||||
) {
|
||||
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
if (src0_type == GGML_TYPE_BF16 ) {
|
||||
// TODO: support GGML_TYPE_BF16
|
||||
// FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: The configuration below needs more work to be supported with oneDNN
|
||||
if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
|
||||
|
||||
@@ -4604,12 +4604,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
{"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
|
||||
{"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
|
||||
};
|
||||
const bool use_subgroup_reduce = device->subgroup_arithmetic;
|
||||
for (uint32_t si = 0; si < 3; si++) {
|
||||
const uint32_t S_V = gdn_sizes[si];
|
||||
GGML_ASSERT(is_pow2(S_V));
|
||||
|
||||
uint32_t lanes_per_column;
|
||||
if (S_V >= 128u && device->subgroup_clustered) {
|
||||
lanes_per_column = 8u;
|
||||
} else {
|
||||
// Use largest power-of-two that divides both S_V and subgroup_size so that
|
||||
// (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0.
|
||||
// This means we don't need extra bounds checking logic in the shader.
|
||||
lanes_per_column = std::min(S_V, device->subgroup_size);
|
||||
}
|
||||
|
||||
const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size);
|
||||
size_t gdn_len;
|
||||
const void * gdn_data;
|
||||
if (use_subgroup_reduce && need_clustered_shader) {
|
||||
gdn_len = gated_delta_net_f32_len;
|
||||
gdn_data = (const void *)gated_delta_net_f32_data;
|
||||
} else if (use_subgroup_reduce) {
|
||||
gdn_len = gated_delta_net_f32_nocluster_len;
|
||||
gdn_data = (const void *)gated_delta_net_f32_nocluster_data;
|
||||
} else {
|
||||
gdn_len = gated_delta_net_f32_shmem_len;
|
||||
gdn_data = (const void *)gated_delta_net_f32_shmem_data;
|
||||
}
|
||||
|
||||
const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column;
|
||||
const std::array<uint32_t, 3> wg_denoms = {1u, 1u, cols_per_wg};
|
||||
|
||||
for (uint32_t kda = 0; kda < 2; kda++) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
|
||||
gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
|
||||
"main", 7, sizeof(vk_op_gated_delta_net_push_constants),
|
||||
{1, 1, 1}, {gdn_sizes[si], kda}, 1);
|
||||
gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
|
||||
wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -10438,7 +10468,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
|
||||
pc, { H, n_seqs, 1u });
|
||||
pc, { H, n_seqs, S_v });
|
||||
}
|
||||
|
||||
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
|
||||
@@ -16018,6 +16048,7 @@ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev)
|
||||
case 0xE20C: // B570
|
||||
return 18;
|
||||
case 0xE20B: // B580
|
||||
case 0xE211: // Pro B60
|
||||
return 20;
|
||||
default:
|
||||
return 0;
|
||||
|
||||
@@ -1,11 +1,25 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#if USE_SUBGROUP_CLUSTERED
|
||||
#extension GL_KHR_shader_subgroup_clustered : enable
|
||||
#endif
|
||||
#if USE_SUBGROUP_ADD
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#endif
|
||||
|
||||
// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0,
|
||||
// so no bounds checking is needed.
|
||||
layout(constant_id = 0) const uint S_V = 128;
|
||||
layout(constant_id = 1) const uint KDA = 0;
|
||||
layout(constant_id = 2) const uint SUBGROUP_SIZE = 32;
|
||||
layout(constant_id = 3) const uint LANES_PER_COLUMN = 32;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN;
|
||||
const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN;
|
||||
|
||||
layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout(push_constant) uniform Parameters {
|
||||
uint H;
|
||||
@@ -27,14 +41,61 @@ layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; };
|
||||
layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; };
|
||||
layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; };
|
||||
|
||||
shared FLOAT_TYPE s_k[S_V];
|
||||
shared FLOAT_TYPE s_q[S_V];
|
||||
shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])
|
||||
#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED
|
||||
shared FLOAT_TYPE temp[SUBGROUP_SIZE];
|
||||
|
||||
// This does a reduction across groups of LANES_PER_COLUMN
|
||||
FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) {
|
||||
const uint lane = gl_SubgroupInvocationID;
|
||||
temp[lane] = partial;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) {
|
||||
FLOAT_TYPE other = temp[lane ^ s];
|
||||
barrier();
|
||||
temp[lane] += other;
|
||||
barrier();
|
||||
}
|
||||
const FLOAT_TYPE result = temp[lane];
|
||||
barrier();
|
||||
return result;
|
||||
}
|
||||
#endif
|
||||
|
||||
// clusterSize for subgroupClusteredAdd must be a compile-time constant; branch on spec constant
|
||||
FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) {
|
||||
switch (LANES_PER_COLUMN) {
|
||||
case 1u:
|
||||
return partial;
|
||||
#if USE_SUBGROUP_CLUSTERED
|
||||
// Workaround for GLSL requiring a literal constant for the cluster size.
|
||||
// The branches should all fold away.
|
||||
case 2u:
|
||||
return subgroupClusteredAdd(partial, 2u);
|
||||
case 4u:
|
||||
return subgroupClusteredAdd(partial, 4u);
|
||||
case 8u:
|
||||
return subgroupClusteredAdd(partial, 8u);
|
||||
case 16u:
|
||||
return subgroupClusteredAdd(partial, 16u);
|
||||
case 32u:
|
||||
return subgroupClusteredAdd(partial, 32u);
|
||||
case 64u:
|
||||
return subgroupClusteredAdd(partial, 64u);
|
||||
#endif
|
||||
default:
|
||||
#if USE_SUBGROUP_ADD
|
||||
return subgroupAdd(partial);
|
||||
#else
|
||||
return reduce_add_shmem(partial);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint head_id = gl_WorkGroupID.x;
|
||||
const uint seq_id = gl_WorkGroupID.y;
|
||||
const uint col = gl_LocalInvocationID.x;
|
||||
const uint seq_id = gl_WorkGroupID.y;
|
||||
const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN;
|
||||
const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN);
|
||||
|
||||
const uint iq1 = head_id % neq1;
|
||||
const uint iq3 = seq_id / rq3;
|
||||
@@ -42,9 +103,9 @@ void main() {
|
||||
const uint state_size = S_V * S_V;
|
||||
const uint state_base = (seq_id * H + head_id) * state_size;
|
||||
|
||||
FLOAT_TYPE state[S_V];
|
||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
||||
state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
|
||||
FLOAT_TYPE s_shard[ROWS_PER_LANE];
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]);
|
||||
}
|
||||
|
||||
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
|
||||
@@ -53,76 +114,56 @@ void main() {
|
||||
const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
|
||||
const uint k_off = q_off;
|
||||
const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
|
||||
|
||||
s_q[col] = FLOAT_TYPE(data_q[q_off + col]);
|
||||
s_k[col] = FLOAT_TYPE(data_k[k_off + col]);
|
||||
|
||||
const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
|
||||
|
||||
if (KDA != 0) {
|
||||
const uint g_base = gb_off * S_V;
|
||||
s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
|
||||
const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
|
||||
|
||||
FLOAT_TYPE k_reg[ROWS_PER_LANE];
|
||||
FLOAT_TYPE q_reg[ROWS_PER_LANE];
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
const uint i = r * LANES_PER_COLUMN + lane;
|
||||
k_reg[r] = FLOAT_TYPE(data_k[k_off + i]);
|
||||
q_reg[r] = FLOAT_TYPE(data_q[q_off + i]);
|
||||
}
|
||||
|
||||
FLOAT_TYPE g_exp[ROWS_PER_LANE];
|
||||
if (KDA == 0) {
|
||||
const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
|
||||
|
||||
FLOAT_TYPE kv_col = 0.0;
|
||||
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
|
||||
kv_col += dot(
|
||||
vec4(state[i], state[i+1], state[i+2], state[i+3]),
|
||||
vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
|
||||
);
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
g_exp[r] = g_val;
|
||||
}
|
||||
|
||||
FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;
|
||||
|
||||
FLOAT_TYPE attn_col = 0.0;
|
||||
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
|
||||
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
|
||||
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
|
||||
sv = g_val * sv + kv * delta_col;
|
||||
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
|
||||
|
||||
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
|
||||
}
|
||||
|
||||
data_dst[attn_off + col] = attn_col * scale;
|
||||
} else {
|
||||
FLOAT_TYPE kv_col = 0.0;
|
||||
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
|
||||
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
|
||||
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
|
||||
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
|
||||
kv_col += dot(gv * sv, kv);
|
||||
const uint g_base = gb_off * S_V;
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
const uint i = r * LANES_PER_COLUMN + lane;
|
||||
g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i]));
|
||||
}
|
||||
}
|
||||
|
||||
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
|
||||
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
|
||||
|
||||
FLOAT_TYPE attn_col = 0.0;
|
||||
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
|
||||
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
|
||||
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
|
||||
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
|
||||
sv = gv * sv + kv * delta_col;
|
||||
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
|
||||
FLOAT_TYPE kv_shard = 0.0;
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
kv_shard += g_exp[r] * s_shard[r] * k_reg[r];
|
||||
}
|
||||
FLOAT_TYPE kv_col = reduce_partial(kv_shard);
|
||||
|
||||
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
|
||||
}
|
||||
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
|
||||
|
||||
FLOAT_TYPE attn_partial = 0.0;
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col;
|
||||
attn_partial += s_shard[r] * q_reg[r];
|
||||
}
|
||||
FLOAT_TYPE attn_col = reduce_partial(attn_partial);
|
||||
|
||||
if (lane == 0) {
|
||||
data_dst[attn_off + col] = attn_col * scale;
|
||||
}
|
||||
|
||||
attn_off += S_V * H;
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
||||
data_dst[s_off + state_base + col * S_V + i] = state[i];
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -987,7 +987,9 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}}));
|
||||
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "1"}}));
|
||||
string_to_spv("gated_delta_net_f32_nocluster", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
|
||||
string_to_spv("gated_delta_net_f32_shmem", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "0"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
|
||||
|
||||
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
@@ -95,6 +95,11 @@ struct ggml_webgpu_generic_shader_decisions {
|
||||
uint32_t wg_size = 0;
|
||||
};
|
||||
|
||||
struct ggml_webgpu_ssm_conv_shader_decisions {
|
||||
uint32_t block_size;
|
||||
uint32_t tokens_per_wg;
|
||||
};
|
||||
|
||||
/** Argsort **/
|
||||
|
||||
struct ggml_webgpu_argsort_shader_lib_context {
|
||||
@@ -131,6 +136,26 @@ struct ggml_webgpu_set_rows_shader_decisions {
|
||||
uint32_t wg_size;
|
||||
};
|
||||
|
||||
/** Set **/
|
||||
|
||||
struct ggml_webgpu_set_pipeline_key {
|
||||
ggml_type type;
|
||||
bool inplace;
|
||||
|
||||
bool operator==(const ggml_webgpu_set_pipeline_key & other) const {
|
||||
return type == other.type && inplace == other.inplace;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_set_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_set_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** Get Rows **/
|
||||
|
||||
struct ggml_webgpu_get_rows_pipeline_key {
|
||||
@@ -186,6 +211,67 @@ struct ggml_webgpu_pad_pipeline_key_hash {
|
||||
}
|
||||
};
|
||||
|
||||
/** Solve Tri **/
|
||||
struct ggml_webgpu_solve_tri_pipeline_key {
|
||||
int type;
|
||||
int n;
|
||||
int k;
|
||||
|
||||
bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const {
|
||||
return type == other.type && n == other.n && k == other.k;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_solve_tri_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.n);
|
||||
ggml_webgpu_hash_combine(seed, key.k);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** SSM Conv **/
|
||||
struct ggml_webgpu_ssm_conv_pipeline_key {
|
||||
int type;
|
||||
int vectorized;
|
||||
|
||||
bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const {
|
||||
return type == other.type && vectorized == other.vectorized;
|
||||
}
|
||||
};
|
||||
|
||||
/** Gated Delta Net **/
|
||||
struct ggml_webgpu_gated_delta_net_pipeline_key {
|
||||
int type;
|
||||
int s_v;
|
||||
int kda;
|
||||
|
||||
bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const {
|
||||
return type == other.type && s_v == other.s_v && kda == other.kda;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_gated_delta_net_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.s_v);
|
||||
ggml_webgpu_hash_combine(seed, key.kda);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_ssm_conv_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** Scale **/
|
||||
|
||||
struct ggml_webgpu_scale_pipeline_key {
|
||||
@@ -466,14 +552,22 @@ class ggml_webgpu_shader_lib {
|
||||
unary_pipelines; // type/op/inplace
|
||||
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
|
||||
scale_pipelines; // inplace
|
||||
std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
|
||||
solve_tri_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
|
||||
ssm_conv_pipelines; // type/vectorized
|
||||
std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
|
||||
webgpu_pipeline,
|
||||
ggml_webgpu_gated_delta_net_pipeline_key_hash>
|
||||
gated_delta_net_pipelines; // type/S_v/kda
|
||||
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
|
||||
pad_pipelines; // circular/non-circular
|
||||
pad_pipelines; // circular/non-circular
|
||||
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
|
||||
binary_pipelines; // type/op/inplace/overlap
|
||||
binary_pipelines; // type/op/inplace/overlap
|
||||
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
|
||||
concat_pipelines; // type
|
||||
concat_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
|
||||
repeat_pipelines; // type
|
||||
repeat_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
||||
flash_attn_pipelines;
|
||||
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
|
||||
@@ -487,6 +581,7 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
|
||||
set_rows_pipelines;
|
||||
std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
|
||||
|
||||
public:
|
||||
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
|
||||
@@ -519,11 +614,11 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
switch (key.op) {
|
||||
case GGML_OP_RMS_NORM:
|
||||
defines.push_back("OP_RMS_NORM");
|
||||
defines.push_back("RMS_NORM");
|
||||
variant = "rms_norm";
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
defines.push_back("OP_L2_NORM");
|
||||
defines.push_back("L2_NORM");
|
||||
variant = "l2_norm";
|
||||
break;
|
||||
default:
|
||||
@@ -535,8 +630,9 @@ class ggml_webgpu_shader_lib {
|
||||
variant += "_inplace";
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
const uint32_t row_norm_wg_size = 128u;
|
||||
uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
|
||||
row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
return row_norm_pipelines[key];
|
||||
@@ -609,6 +705,46 @@ class ggml_webgpu_shader_lib {
|
||||
return set_rows_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace };
|
||||
|
||||
auto it = set_pipelines.find(key);
|
||||
if (it != set_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "set";
|
||||
|
||||
switch (key.type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("TYPE_F32");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_I32:
|
||||
defines.push_back("TYPE_I32");
|
||||
variant += "_i32";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for set shader");
|
||||
}
|
||||
|
||||
if (key.inplace) {
|
||||
defines.push_back("INPLACE");
|
||||
variant += "_inplace";
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_set, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
set_pipelines[key] = pipeline;
|
||||
return set_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
auto it = cumsum_pipelines.find(1);
|
||||
if (it != cumsum_pipelines.end()) {
|
||||
@@ -695,6 +831,7 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
switch (key.src_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("FLOAT_PARALLEL");
|
||||
if (key.vectorized) {
|
||||
defines.push_back("F32_VEC");
|
||||
defines.push_back("SRC_TYPE=vec4<f32>");
|
||||
@@ -709,6 +846,7 @@ class ggml_webgpu_shader_lib {
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("FLOAT_PARALLEL");
|
||||
defines.push_back("F16");
|
||||
defines.push_back("SRC_TYPE=f16");
|
||||
defines.push_back("DST_TYPE=f32");
|
||||
@@ -716,6 +854,7 @@ class ggml_webgpu_shader_lib {
|
||||
variant += "_f16";
|
||||
break;
|
||||
case GGML_TYPE_I32:
|
||||
defines.push_back("FLOAT_PARALLEL");
|
||||
defines.push_back("I32");
|
||||
defines.push_back("SRC_TYPE=i32");
|
||||
defines.push_back("DST_TYPE=i32");
|
||||
@@ -794,6 +933,128 @@ class ggml_webgpu_shader_lib {
|
||||
return scale_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_solve_tri_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.n = (int) context.src0->ne[0],
|
||||
.k = (int) context.src1->ne[0],
|
||||
};
|
||||
|
||||
auto it = solve_tri_pipelines.find(key);
|
||||
if (it != solve_tri_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "solve_tri";
|
||||
|
||||
switch (key.type) {
|
||||
case GGML_TYPE_F32:
|
||||
variant += "_f32";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for solve_tri shader");
|
||||
}
|
||||
|
||||
const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size);
|
||||
const uint32_t k_tile = wg_size;
|
||||
const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row);
|
||||
|
||||
defines.push_back(std::string("N=") + std::to_string(key.n));
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
defines.push_back(std::string("K_TILE=") + std::to_string(k_tile));
|
||||
defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_solve_tri, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = wg_size;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
solve_tri_pipelines[key] = pipeline;
|
||||
return solve_tri_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_ssm_conv_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.vectorized = context.src1->ne[0] == 4,
|
||||
};
|
||||
|
||||
auto it = ssm_conv_pipelines.find(key);
|
||||
if (it != ssm_conv_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "ssm_conv";
|
||||
|
||||
switch (key.type) {
|
||||
case GGML_TYPE_F32:
|
||||
variant += "_f32";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for ssm_conv shader");
|
||||
}
|
||||
|
||||
if (key.vectorized) {
|
||||
defines.push_back("VECTORIZED");
|
||||
variant += "_vec4";
|
||||
}
|
||||
|
||||
constexpr uint32_t block_size = 32u;
|
||||
constexpr uint32_t tokens_per_wg = 8u;
|
||||
|
||||
defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u");
|
||||
defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u");
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_ssm_conv_shader_decisions>();
|
||||
decisions->block_size = block_size;
|
||||
decisions->tokens_per_wg = tokens_per_wg;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
ssm_conv_pipelines[key] = pipeline;
|
||||
return ssm_conv_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_gated_delta_net_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.s_v = (int) context.src2->ne[0],
|
||||
.kda = context.src3->ne[0] == context.src2->ne[0],
|
||||
};
|
||||
|
||||
auto it = gated_delta_net_pipelines.find(key);
|
||||
if (it != gated_delta_net_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "gated_delta_net";
|
||||
|
||||
switch (key.type) {
|
||||
case GGML_TYPE_F32:
|
||||
variant += "_f32";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for gated_delta_net shader");
|
||||
}
|
||||
|
||||
if (key.kda) {
|
||||
defines.push_back("KDA");
|
||||
variant += "_kda";
|
||||
}
|
||||
|
||||
defines.push_back("S_V=" + std::to_string(key.s_v) + "u");
|
||||
defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u");
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines);
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
gated_delta_net_pipelines[key] = pipeline;
|
||||
return gated_delta_net_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
|
||||
|
||||
|
||||
@@ -880,6 +880,68 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
|
||||
params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const bool inplace = ggml_webgpu_tensor_equal(src0, dst);
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src0,
|
||||
.src1 = src1,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
.inplace = inplace,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst);
|
||||
const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
ne,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (((const int32_t *) dst->op_params)[3] / dst_type_size),
|
||||
|
||||
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
|
||||
1u,
|
||||
(uint32_t) (((const int32_t *) dst->op_params)[0] / dst_type_size),
|
||||
(uint32_t) (((const int32_t *) dst->op_params)[1] / dst_type_size),
|
||||
(uint32_t) (((const int32_t *) dst->op_params)[2] / dst_type_size),
|
||||
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src1->ne[1],
|
||||
(uint32_t) src1->ne[2],
|
||||
(uint32_t) src1->ne[3],
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries;
|
||||
uint32_t binding_index = 0;
|
||||
if (!inplace) {
|
||||
entries.push_back({ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) });
|
||||
binding_index++;
|
||||
}
|
||||
entries.push_back({ .binding = binding_index,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) });
|
||||
entries.push_back({ .binding = binding_index + 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
|
||||
@@ -935,6 +997,208 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src0,
|
||||
.src1 = src1,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
|
||||
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||
|
||||
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
|
||||
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) dst->ne[2],
|
||||
(uint32_t) dst->ne[3],
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size);
|
||||
const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src0,
|
||||
.src1 = src1,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_ssm_conv_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
const uint32_t token_tiles = CEIL_DIV((uint32_t) dst->ne[1], decisions->tokens_per_wg);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
|
||||
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src0->ne[1],
|
||||
(uint32_t) dst->ne[1],
|
||||
(uint32_t) dst->ne[2],
|
||||
token_tiles,
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size);
|
||||
const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2];
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * src2,
|
||||
ggml_tensor * src3,
|
||||
ggml_tensor * src4,
|
||||
ggml_tensor * src5,
|
||||
ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src0,
|
||||
.src1 = src1,
|
||||
.src2 = src2,
|
||||
.src3 = src3,
|
||||
.src4 = src4,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx);
|
||||
|
||||
const uint32_t s_v = (uint32_t) src2->ne[0];
|
||||
const uint32_t h = (uint32_t) src2->ne[1];
|
||||
const uint32_t n_tokens = (uint32_t) src2->ne[2];
|
||||
const uint32_t n_seqs = (uint32_t) src2->ne[3];
|
||||
const float scale = 1.0f / sqrtf((float) s_v);
|
||||
uint32_t scale_u32;
|
||||
memcpy(&scale_u32, &scale, sizeof(scale_u32));
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
h,
|
||||
n_tokens,
|
||||
n_seqs,
|
||||
s_v * h * n_tokens * n_seqs,
|
||||
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||
|
||||
(uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
|
||||
(uint32_t) (src2->nb[2] / ggml_type_size(src2->type)),
|
||||
(uint32_t) (src2->nb[3] / ggml_type_size(src2->type)),
|
||||
|
||||
(uint32_t) (src4->nb[1] / ggml_type_size(src4->type)),
|
||||
(uint32_t) (src4->nb[2] / ggml_type_size(src4->type)),
|
||||
(uint32_t) (src4->nb[3] / ggml_type_size(src4->type)),
|
||||
|
||||
(uint32_t) src0->ne[1],
|
||||
(uint32_t) (src2->ne[3] / src0->ne[3]),
|
||||
scale_u32,
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(src2),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src2),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src2) },
|
||||
{ .binding = 3,
|
||||
.buffer = ggml_webgpu_tensor_buf(src3),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src3),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src3) },
|
||||
{ .binding = 4,
|
||||
.buffer = ggml_webgpu_tensor_buf(src4),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src4),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src4) },
|
||||
{ .binding = 5,
|
||||
.buffer = ggml_webgpu_tensor_buf(src5),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src5),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src5) },
|
||||
{ .binding = 6,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, h, n_seqs);
|
||||
}
|
||||
|
||||
static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
||||
ggml_tensor * src,
|
||||
ggml_tensor * idx,
|
||||
@@ -1016,6 +1280,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
|
||||
ggml_tensor * src,
|
||||
ggml_tensor * idx,
|
||||
ggml_tensor * dst) {
|
||||
const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32;
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src,
|
||||
.src1 = nullptr,
|
||||
@@ -1060,7 +1326,10 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);
|
||||
uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1));
|
||||
uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]);
|
||||
uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows;
|
||||
uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size);
|
||||
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
@@ -1632,7 +1901,7 @@ static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * s
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src,
|
||||
.dst = dst,
|
||||
.max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
.inplace = inplace,
|
||||
};
|
||||
|
||||
@@ -2176,6 +2445,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
return ggml_webgpu_cpy(ctx, src0, node);
|
||||
case GGML_OP_SET:
|
||||
return ggml_webgpu_set(ctx, src0, src1, node);
|
||||
case GGML_OP_SET_ROWS:
|
||||
return ggml_webgpu_set_rows(ctx, src0, src1, node);
|
||||
case GGML_OP_GET_ROWS:
|
||||
@@ -2219,6 +2490,12 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_TRI:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
return ggml_webgpu_solve_tri(ctx, src0, src1, node);
|
||||
case GGML_OP_SSM_CONV:
|
||||
return ggml_webgpu_ssm_conv(ctx, src0, src1, node);
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node);
|
||||
case GGML_OP_PAD:
|
||||
return ggml_webgpu_pad(ctx, src0, node);
|
||||
case GGML_OP_ARGMAX:
|
||||
@@ -2957,7 +3234,7 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
|
||||
/* .is_host = */ NULL, // defaults to false
|
||||
},
|
||||
/* .device = */
|
||||
dev,
|
||||
dev,
|
||||
/* .context = */ NULL
|
||||
};
|
||||
|
||||
@@ -3040,6 +3317,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
|
||||
(op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
|
||||
break;
|
||||
case GGML_OP_SET:
|
||||
supports_op = src0->type == src1->type && src0->type == op->type &&
|
||||
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32);
|
||||
break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
|
||||
(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
|
||||
@@ -3180,6 +3461,27 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
}
|
||||
}
|
||||
break;
|
||||
case GGML_OP_TRI:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_DIAG:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
supports_op = op->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
{
|
||||
const uint32_t s_v = (uint32_t) src2->ne[0];
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 &&
|
||||
src2->type == GGML_TYPE_F32 && op->src[3]->type == GGML_TYPE_F32 &&
|
||||
op->src[4]->type == GGML_TYPE_F32 && op->src[5]->type == GGML_TYPE_F32 &&
|
||||
s_v <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_CLAMP:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
@@ -3201,12 +3503,6 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_OP_COS:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
case GGML_OP_DIAG:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
case GGML_OP_TRI:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
case GGML_OP_PAD:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
|
||||
132
ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl
Normal file
132
ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl
Normal file
@@ -0,0 +1,132 @@
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src_q: array<f32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src_k: array<f32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> src_v: array<f32>;
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<storage, read_write> src_g: array<f32>;
|
||||
|
||||
@group(0) @binding(4)
|
||||
var<storage, read_write> src_beta: array<f32>;
|
||||
|
||||
@group(0) @binding(5)
|
||||
var<storage, read_write> src_state: array<f32>;
|
||||
|
||||
@group(0) @binding(6)
|
||||
var<storage, read_write> dst: array<f32>;
|
||||
|
||||
struct Params {
|
||||
h: u32,
|
||||
n_tokens: u32,
|
||||
n_seqs: u32,
|
||||
s_off: u32,
|
||||
|
||||
sq1: u32,
|
||||
sq2: u32,
|
||||
sq3: u32,
|
||||
|
||||
sv1: u32,
|
||||
sv2: u32,
|
||||
sv3: u32,
|
||||
|
||||
sb1: u32,
|
||||
sb2: u32,
|
||||
sb3: u32,
|
||||
|
||||
neq1: u32,
|
||||
rq3: u32,
|
||||
scale: f32,
|
||||
};
|
||||
|
||||
@group(0) @binding(7)
|
||||
var<uniform> params: Params;
|
||||
|
||||
var<workgroup> sh_k: array<f32, S_V>;
|
||||
var<workgroup> sh_q: array<f32, S_V>;
|
||||
var<workgroup> sh_g: array<f32, S_V>;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(
|
||||
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>
|
||||
) {
|
||||
let head_id = workgroup_id.x;
|
||||
let seq_id = workgroup_id.y;
|
||||
let col = local_id.x;
|
||||
|
||||
let iq1 = head_id % params.neq1;
|
||||
let iq3 = seq_id / params.rq3;
|
||||
|
||||
let state_size = S_V * S_V;
|
||||
let state_base = (seq_id * params.h + head_id) * state_size;
|
||||
|
||||
var state: array<f32, S_V>;
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
state[i] = src_state[state_base + col * S_V + i];
|
||||
}
|
||||
|
||||
var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V;
|
||||
|
||||
for (var t = 0u; t < params.n_tokens; t++) {
|
||||
let q_off = iq3 * params.sq3 + t * params.sq2 + iq1 * params.sq1;
|
||||
let k_off = q_off;
|
||||
let v_off = seq_id * params.sv3 + t * params.sv2 + head_id * params.sv1;
|
||||
let gb_off = seq_id * params.sb3 + t * params.sb2 + head_id * params.sb1;
|
||||
|
||||
sh_q[col] = src_q[q_off + col];
|
||||
sh_k[col] = src_k[k_off + col];
|
||||
|
||||
#ifdef KDA
|
||||
let g_base = gb_off * S_V;
|
||||
sh_g[col] = exp(src_g[g_base + col]);
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
let v_val = src_v[v_off + col];
|
||||
let beta_val = src_beta[gb_off];
|
||||
|
||||
var kv_col = 0.0;
|
||||
var delta_col = 0.0;
|
||||
var attn_col = 0.0;
|
||||
|
||||
#ifdef KDA
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
kv_col += (sh_g[i] * state[i]) * sh_k[i];
|
||||
}
|
||||
|
||||
delta_col = (v_val - kv_col) * beta_val;
|
||||
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
state[i] = sh_g[i] * state[i] + sh_k[i] * delta_col;
|
||||
attn_col += state[i] * sh_q[i];
|
||||
}
|
||||
#else
|
||||
let g_val = exp(src_g[gb_off]);
|
||||
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
kv_col += state[i] * sh_k[i];
|
||||
}
|
||||
|
||||
delta_col = (v_val - g_val * kv_col) * beta_val;
|
||||
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
state[i] = g_val * state[i] + sh_k[i] * delta_col;
|
||||
attn_col += state[i] * sh_q[i];
|
||||
}
|
||||
#endif
|
||||
|
||||
dst[attn_off + col] = attn_col * params.scale;
|
||||
attn_off += S_V * params.h;
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
dst[params.s_off + state_base + col * S_V + i] = state[i];
|
||||
}
|
||||
}
|
||||
@@ -640,6 +640,35 @@ var<uniform> params: Params;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
#ifdef FLOAT_PARALLEL
|
||||
let blocks_per_row = params.ne0 / BLOCK_SIZE;
|
||||
let row_count = params.n_rows * params.ne2 * params.ne3;
|
||||
|
||||
if (gid.x >= blocks_per_row * row_count) {
|
||||
return;
|
||||
}
|
||||
|
||||
let block_idx = gid.x % blocks_per_row;
|
||||
var row_idx = gid.x / blocks_per_row;
|
||||
let i_dst3 = row_idx / (params.ne2 * params.n_rows);
|
||||
|
||||
row_idx = row_idx % (params.ne2 * params.n_rows);
|
||||
let i_dst2 = row_idx / params.n_rows;
|
||||
let i_dst1 = row_idx % params.n_rows;
|
||||
|
||||
let i_idx2 = i_dst3 % params.idx2;
|
||||
let i_idx1 = i_dst2 % params.idx1;
|
||||
let i_idx0 = i_dst1;
|
||||
|
||||
let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
|
||||
|
||||
let idx_val = u32(idx[i_idx]);
|
||||
|
||||
let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
|
||||
let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
|
||||
|
||||
copy_elements(i_src_row, i_dst_row, block_idx);
|
||||
#else
|
||||
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
|
||||
return;
|
||||
}
|
||||
@@ -664,5 +693,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
|
||||
copy_elements(i_src_row, i_dst_row, i);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -81,11 +81,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
||||
}
|
||||
sum = scratch[0];
|
||||
|
||||
#ifdef OP_RMS_NORM
|
||||
#ifdef RMS_NORM
|
||||
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
|
||||
#elif OP_L2_NORM
|
||||
#elif defined(L2_NORM)
|
||||
let scale = 1.0/max(sqrt(sum), params.eps);
|
||||
#endif
|
||||
|
||||
col = lid.x;
|
||||
for (var j: u32 = 0; j < elems; j++) {
|
||||
if (col >= params.ne0) {
|
||||
|
||||
109
ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl
Normal file
109
ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl
Normal file
@@ -0,0 +1,109 @@
|
||||
#ifdef TYPE_I32
|
||||
#define TYPE i32
|
||||
#else
|
||||
#define TYPE f32
|
||||
#endif
|
||||
|
||||
#ifndef INPLACE
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<TYPE>;
|
||||
#define SRC1_BINDING 1
|
||||
#else
|
||||
#define SRC1_BINDING 0
|
||||
#endif
|
||||
|
||||
#define DST_BINDING SRC1_BINDING + 1
|
||||
#define PARAMS_BINDING SRC1_BINDING + 2
|
||||
|
||||
@group(0) @binding(SRC1_BINDING)
|
||||
var<storage, read_write> src1: array<TYPE>;
|
||||
|
||||
@group(0) @binding(DST_BINDING)
|
||||
var<storage, read_write> dst: array<TYPE>;
|
||||
|
||||
struct Params {
|
||||
ne: u32,
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_view: u32,
|
||||
|
||||
stride_src10: u32,
|
||||
stride_src11: u32,
|
||||
stride_src12: u32,
|
||||
stride_src13: u32,
|
||||
|
||||
stride_dst10: u32,
|
||||
stride_dst11: u32,
|
||||
stride_dst12: u32,
|
||||
stride_dst13: u32,
|
||||
|
||||
src1_ne0: u32,
|
||||
src1_ne1: u32,
|
||||
src1_ne2: u32,
|
||||
src1_ne3: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(PARAMS_BINDING)
|
||||
var<uniform> params: Params;
|
||||
|
||||
fn decode_src1_coords(idx: u32) -> vec4<u32> {
|
||||
var i = idx;
|
||||
let plane = params.src1_ne2 * params.src1_ne1 * params.src1_ne0;
|
||||
let i3 = i / plane;
|
||||
i = i % plane;
|
||||
let row = params.src1_ne1 * params.src1_ne0;
|
||||
let i2 = i / row;
|
||||
i = i % row;
|
||||
let i1 = i / params.src1_ne0;
|
||||
let i0 = i % params.src1_ne0;
|
||||
return vec4<u32>(i0, i1, i2, i3);
|
||||
}
|
||||
|
||||
fn decode_view_coords(rel: u32) -> vec4<u32> {
|
||||
let i3 = rel / params.stride_dst13;
|
||||
let rem3 = rel % params.stride_dst13;
|
||||
let i2 = rem3 / params.stride_dst12;
|
||||
let rem2 = rem3 % params.stride_dst12;
|
||||
let i1 = rem2 / params.stride_dst11;
|
||||
let i0 = rem2 % params.stride_dst11;
|
||||
return vec4<u32>(i0, i1, i2, i3);
|
||||
}
|
||||
|
||||
fn view_rel_from_coords(coords: vec4<u32>) -> u32 {
|
||||
return coords.x * params.stride_dst10 + coords.y * params.stride_dst11 +
|
||||
coords.z * params.stride_dst12 + coords.w * params.stride_dst13;
|
||||
}
|
||||
|
||||
fn src1_idx_from_coords(coords: vec4<u32>) -> u32 {
|
||||
return coords.x * params.stride_src10 + coords.y * params.stride_src11 +
|
||||
coords.z * params.stride_src12 + coords.w * params.stride_src13;
|
||||
}
|
||||
|
||||
fn in_set_view(rel: u32, coords: vec4<u32>) -> bool {
|
||||
return view_rel_from_coords(coords) == rel;
|
||||
}
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef INPLACE
|
||||
let coords = decode_src1_coords(gid.x);
|
||||
|
||||
let src1_idx = params.offset_src1 + src1_idx_from_coords(coords);
|
||||
let dst_idx = params.offset_view + view_rel_from_coords(coords);
|
||||
|
||||
dst[dst_idx] = src1[src1_idx];
|
||||
#else
|
||||
let rel = select(params.ne, gid.x - params.offset_view, gid.x >= params.offset_view);
|
||||
let coords = decode_view_coords(rel);
|
||||
|
||||
if (rel < params.stride_dst13 * params.src1_ne3 && in_set_view(rel, coords)) {
|
||||
dst[gid.x] = src1[params.offset_src1 + src1_idx_from_coords(coords)];
|
||||
} else {
|
||||
dst[gid.x] = src0[params.offset_src0 + gid.x];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
121
ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl
Normal file
121
ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl
Normal file
@@ -0,0 +1,121 @@
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<f32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src1: array<f32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> dst: array<f32>;
|
||||
|
||||
struct Params {
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
stride_src00: u32,
|
||||
stride_src01: u32,
|
||||
stride_src02: u32,
|
||||
stride_src03: u32,
|
||||
|
||||
stride_src10: u32,
|
||||
stride_src11: u32,
|
||||
stride_src12: u32,
|
||||
stride_src13: u32,
|
||||
|
||||
stride_dst0: u32,
|
||||
stride_dst1: u32,
|
||||
stride_dst2: u32,
|
||||
stride_dst3: u32,
|
||||
|
||||
k: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
var<workgroup> shA: array<f32, BATCH_N * N>;
|
||||
var<workgroup> shB: array<f32, BATCH_N * K_TILE>;
|
||||
|
||||
fn src0_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
|
||||
return params.offset_src0 +
|
||||
col * params.stride_src00 +
|
||||
row * params.stride_src01 +
|
||||
i2 * params.stride_src02 +
|
||||
i3 * params.stride_src03;
|
||||
}
|
||||
|
||||
fn src1_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
|
||||
return params.offset_src1 +
|
||||
col * params.stride_src10 +
|
||||
row * params.stride_src11 +
|
||||
i2 * params.stride_src12 +
|
||||
i3 * params.stride_src13;
|
||||
}
|
||||
|
||||
fn dst_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
|
||||
return params.offset_dst +
|
||||
col * params.stride_dst0 +
|
||||
row * params.stride_dst1 +
|
||||
i2 * params.stride_dst2 +
|
||||
i3 * params.stride_dst3;
|
||||
}
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(
|
||||
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>
|
||||
) {
|
||||
let batch = workgroup_id.y;
|
||||
let col = workgroup_id.x * WG_SIZE + local_id.x;
|
||||
let i3 = batch / params.ne2;
|
||||
let i2 = batch % params.ne2;
|
||||
let active_lane = local_id.x < K_TILE;
|
||||
let active_col = active_lane && col < params.k;
|
||||
|
||||
var X: array<f32, N>;
|
||||
|
||||
for (var row_base = 0u; row_base < N; row_base += BATCH_N) {
|
||||
let cur_n = min(BATCH_N, N - row_base);
|
||||
|
||||
for (var i = local_id.x; i < cur_n * N; i += WG_SIZE) {
|
||||
let tile_row = i / N;
|
||||
let tile_col = i % N;
|
||||
shA[i] = src0[src0_idx(row_base + tile_row, tile_col, i2, i3)];
|
||||
}
|
||||
|
||||
for (var i = local_id.x; i < cur_n * K_TILE; i += WG_SIZE) {
|
||||
let tile_row = i / K_TILE;
|
||||
let tile_col = i % K_TILE;
|
||||
let global_col = workgroup_id.x * WG_SIZE + tile_col;
|
||||
let sh_idx = tile_row * K_TILE + tile_col;
|
||||
|
||||
if (global_col < params.k) {
|
||||
shB[sh_idx] = src1[src1_idx(row_base + tile_row, global_col, i2, i3)];
|
||||
} else {
|
||||
shB[sh_idx] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
if (active_col) {
|
||||
for (var row_offset = 0u; row_offset < cur_n; row_offset++) {
|
||||
let r = row_base + row_offset;
|
||||
var b = shB[row_offset * K_TILE + local_id.x];
|
||||
let a_row = row_offset * N;
|
||||
|
||||
for (var t = 0u; t < r; t++) {
|
||||
b -= shA[a_row + t] * X[t];
|
||||
}
|
||||
|
||||
let x = b / shA[a_row + r];
|
||||
X[r] = x;
|
||||
dst[dst_idx(r, col, i2, i3)] = x;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
}
|
||||
65
ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl
Normal file
65
ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl
Normal file
@@ -0,0 +1,65 @@
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<f32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src1: array<f32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> dst: array<f32>;
|
||||
|
||||
struct Params {
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
stride_src01: u32,
|
||||
stride_src02: u32,
|
||||
stride_src11: u32,
|
||||
|
||||
stride_dst0: u32,
|
||||
stride_dst1: u32,
|
||||
stride_dst2: u32,
|
||||
|
||||
nc: u32,
|
||||
nr: u32,
|
||||
n_t: u32,
|
||||
n_s: u32,
|
||||
token_tiles: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
@compute @workgroup_size(BLOCK_SIZE, TOKENS_PER_WG)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
let i1 = gid.x;
|
||||
let tile_y = gid.y / TOKENS_PER_WG;
|
||||
let local_token = gid.y % TOKENS_PER_WG;
|
||||
let i3 = tile_y / params.token_tiles;
|
||||
let token_tile = tile_y % params.token_tiles;
|
||||
let i2 = token_tile * TOKENS_PER_WG + local_token;
|
||||
|
||||
if (i1 >= params.nr || i2 >= params.n_t || i3 >= params.n_s) {
|
||||
return;
|
||||
}
|
||||
|
||||
let src0_base = params.offset_src0 + i3 * params.stride_src02 + i2 + i1 * params.stride_src01;
|
||||
let src1_base = params.offset_src1 + i1 * params.stride_src11;
|
||||
|
||||
var sum = 0.0;
|
||||
|
||||
#ifdef VECTORIZED
|
||||
sum =
|
||||
src0[src0_base + 0u] * src1[src1_base + 0u] +
|
||||
src0[src0_base + 1u] * src1[src1_base + 1u] +
|
||||
src0[src0_base + 2u] * src1[src1_base + 2u] +
|
||||
src0[src0_base + 3u] * src1[src1_base + 3u];
|
||||
#else
|
||||
for (var i0 = 0u; i0 < params.nc; i0++) {
|
||||
sum += src0[src0_base + i0] * src1[src1_base + i0];
|
||||
}
|
||||
#endif
|
||||
|
||||
let dst_idx = params.offset_dst + i3 * params.stride_dst2 + i2 * params.stride_dst1 + i1 * params.stride_dst0;
|
||||
dst[dst_idx] = sum;
|
||||
}
|
||||
@@ -301,6 +301,8 @@ class Keys:
|
||||
IMAGE_SIZE = "clip.vision.image_size"
|
||||
IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels"
|
||||
IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels"
|
||||
PREPROC_MIN_TILES = "clip.vision.preproc_min_tiles"
|
||||
PREPROC_MAX_TILES = "clip.vision.preproc_max_tiles"
|
||||
PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size"
|
||||
PATCH_SIZE = "clip.vision.patch_size"
|
||||
EMBEDDING_LENGTH = "clip.vision.embedding_length"
|
||||
@@ -3869,6 +3871,8 @@ class LlamaFileType(IntEnum):
|
||||
# MOSTLY_Q4_0_8_8 = 35 # removed from gguf files, use Q4_0 and runtime repack
|
||||
MOSTLY_TQ1_0 = 36 # except 1d tensors
|
||||
MOSTLY_TQ2_0 = 37 # except 1d tensors
|
||||
MOSTLY_MXFP4_MOE = 38 # except 1d tensors
|
||||
MOSTLY_NVFP4 = 39 # except 1d tensors
|
||||
|
||||
GUESSED = 1024 # not specified in the model file
|
||||
|
||||
|
||||
@@ -1156,6 +1156,12 @@ class GGUFWriter:
|
||||
def add_vision_min_pixels(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.IMAGE_MIN_PIXELS, value)
|
||||
|
||||
def add_vision_preproc_max_tiles(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.PREPROC_MAX_TILES, value)
|
||||
|
||||
def add_vision_preproc_min_tiles(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.PREPROC_MIN_TILES, value)
|
||||
|
||||
def add_vision_preproc_image_size(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value)
|
||||
|
||||
@@ -1300,7 +1306,7 @@ class GGUFWriter:
|
||||
else:
|
||||
raise ValueError("Invalid GGUF metadata value type or value")
|
||||
|
||||
return kv_data
|
||||
return bytes(kv_data)
|
||||
|
||||
@staticmethod
|
||||
def format_n_bytes_to_str(num: int) -> str:
|
||||
|
||||
@@ -138,7 +138,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
||||
if isinstance(meta_noop, tuple):
|
||||
dtype, shape = meta_noop
|
||||
assert callable(shape)
|
||||
res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape))
|
||||
res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape)) # ty: ignore[call-top-callable]
|
||||
else:
|
||||
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
|
||||
|
||||
|
||||
@@ -91,11 +91,11 @@ class __Quant(ABC):
|
||||
def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None:
|
||||
cls.qtype = qtype
|
||||
cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype]
|
||||
cls.__quantize_lazy = LazyNumpyTensor._wrap_fn(
|
||||
cls.__quantize_lazy: Any = LazyNumpyTensor._wrap_fn(
|
||||
cls.__quantize_array,
|
||||
meta_noop=(np.uint8, cls.__shape_to_bytes)
|
||||
)
|
||||
cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn(
|
||||
cls.__dequantize_lazy: Any = LazyNumpyTensor._wrap_fn(
|
||||
cls.__dequantize_array,
|
||||
meta_noop=(np.float32, cls.__shape_from_bytes)
|
||||
)
|
||||
|
||||
@@ -11,33 +11,33 @@ from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVa
|
||||
try:
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
except ImportError:
|
||||
SentencePieceProcessor = None
|
||||
SentencePieceProcessor: Any = None
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.utils import ( # type: ignore[import-not-found]
|
||||
_filter_valid_tokenizer_files,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # type: ignore[import-not-found]
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
except ImportError:
|
||||
_mistral_common_installed = False
|
||||
MistralTokenizer = None
|
||||
Tekkenizer = None
|
||||
SentencePieceTokenizer = None
|
||||
_filter_valid_tokenizer_files = None
|
||||
MistralTokenizer: Any = None
|
||||
Tekkenizer: Any = None
|
||||
SentencePieceTokenizer: Any = None
|
||||
_filter_valid_tokenizer_files: Any = None
|
||||
else:
|
||||
_mistral_common_installed = True
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.utils import ( # type: ignore[import-not-found]
|
||||
get_one_valid_tokenizer_file,
|
||||
)
|
||||
except ImportError:
|
||||
# We still want the conversion to work with older mistral-common versions.
|
||||
get_one_valid_tokenizer_file = None
|
||||
get_one_valid_tokenizer_file: Any = None
|
||||
|
||||
|
||||
import gguf
|
||||
@@ -703,7 +703,7 @@ class MistralVocab(Vocab):
|
||||
|
||||
tokenizer_file_path = base_path / tokenizer_file
|
||||
|
||||
self.tokenizer = MistralTokenizer.from_file(
|
||||
self.tokenizer: Any = MistralTokenizer.from_file(
|
||||
tokenizer_file_path
|
||||
).instruct_tokenizer.tokenizer
|
||||
self.tokenizer_type = (
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
{%- set available_tool_string = '' -%}
|
||||
{%- set add_tool_id = true -%}
|
||||
{%- set add_thoughts = true -%} {# whether to include <thinking> reasoning blocks #}
|
||||
{%- set add_generation_prompt = true -%} {# whether to emit reasoning starter before assistant response #}
|
||||
{# Optional token placeholders (safe defaults) #}
|
||||
{%- set bos_token = bos_token or '' -%}
|
||||
{%- set eos_token = eos_token or '' -%}
|
||||
|
||||
@@ -15,10 +15,10 @@
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- for tool in message['tool_calls']-%}
|
||||
{%- if not ns.is_first -%}
|
||||
{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] | tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{%- set ns.is_first = true -%}
|
||||
{%- else -%}
|
||||
{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] | tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
||||
|
||||
@@ -28,25 +28,25 @@
|
||||
{%- set ns.is_last_user = true -%}{{'<|User|>' + message['content']}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and message['tool_calls'] -%}
|
||||
{%- if ns.is_last_user -%}{{'<|Assistant|></think>'}}
|
||||
{%- if ns.is_last_user -%}{{'<|Assistant|><think></think>'}}
|
||||
{%- endif -%}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- set ns.is_first = false -%}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- for tool in message['tool_calls'] -%}
|
||||
{%- if not ns.is_first -%}
|
||||
{%- if not message['content'] -%}{{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}
|
||||
{%- else -%}{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}
|
||||
{%- if not message['content'] -%}{{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] | tojson + '<|tool▁call▁end|>'}}
|
||||
{%- else -%}{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] | tojson + '<|tool▁call▁end|>'}}
|
||||
{%- endif -%}
|
||||
{%- set ns.is_first = true -%}
|
||||
{%- else -%}{{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}
|
||||
{%- else -%}{{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] | tojson + '<|tool▁call▁end|>'}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and not message['tool_calls'] -%}
|
||||
{%- if ns.is_last_user -%}{{'<|Assistant|>'}}
|
||||
{%- if message['prefix'] is defined and message['prefix'] and thinking -%}{{'<think>'}}
|
||||
{%- else -%}{{'</think>'}}
|
||||
{%- else -%}{{'<think></think>'}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
@@ -65,7 +65,7 @@
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool -%}{{'<|Assistant|>'}}
|
||||
{%- if not thinking -%}{{'</think>'}}
|
||||
{%- else -%}{{'<think>'}}
|
||||
{%- if not thinking -%}{{'<think></think>'}}
|
||||
{%- else -%}{{'<think>'}}
|
||||
{%- endif -%}
|
||||
{%- endif %}
|
||||
@@ -49,7 +49,7 @@ Example function tool call syntax:
|
||||
{%- endif -%}
|
||||
{%- set tool_name = tc['function']['name'] -%}
|
||||
{%- set tool_args = tc['function']['arguments'] -%}
|
||||
{{- '<|tool▁call▁begin|>' + tc['type'] + '<|tool▁sep|>' + tool_name + '\n' + '```json' + '\n' + tool_args + '\n' + '```' + '<|tool▁call▁end|>' -}}
|
||||
{{- '<|tool▁call▁begin|>' + tc['type'] + '<|tool▁sep|>' + tool_name + '\n' + '```json' + '\n' + tool_args | tojson + '\n' + '```' + '<|tool▁call▁end|>' -}}
|
||||
{%- endfor -%}
|
||||
{{- '<|tool▁calls▁end|><|end▁of▁sentence|>' -}}
|
||||
{%- endif -%}
|
||||
|
||||
@@ -42,9 +42,9 @@
|
||||
{%- if 'tool_calls' in message and message['tool_calls'] -%}
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{%- if tool_call["function"]["name"] == "python" -%}
|
||||
{{ '<|python_tag|>' + tool_call['function']['arguments'] }}
|
||||
{{ '<|python_tag|>' + tool_call['function']['arguments'] | tojson }}
|
||||
{%- else -%}
|
||||
{{ '<function=' + tool_call['function']['name'] + '>' + tool_call['function']['arguments'] + '</function>' }}
|
||||
{{ '<function=' + tool_call['function']['name'] + '>' + tool_call['function']['arguments'] | tojson + '</function>' }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{ '<|eom_id|>' }}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"extraPaths": ["gguf-py", "examples/model-conversion/scripts"],
|
||||
"extraPaths": ["gguf-py", "examples/model-conversion/scripts", "examples/model-conversion/scripts/utils"],
|
||||
"pythonVersion": "3.9",
|
||||
"pythonPlatform": "All",
|
||||
"reportUnusedImport": "warning",
|
||||
|
||||
@@ -684,6 +684,7 @@ else:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
assert isinstance(hexsha8_baseline, str)
|
||||
name_baseline = bench_data.get_commit_name(hexsha8_baseline)
|
||||
|
||||
hexsha8_compare = name_compare = None
|
||||
@@ -717,6 +718,7 @@ else:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
assert isinstance(hexsha8_compare, str)
|
||||
name_compare = bench_data.get_commit_name(hexsha8_compare)
|
||||
|
||||
# Get tool-specific configuration
|
||||
|
||||
@@ -95,9 +95,9 @@ if __name__ == '__main__':
|
||||
'-p', 'Hey',
|
||||
'--no-warmup',
|
||||
'--log-disable',
|
||||
'-no-cnv']
|
||||
'-st']
|
||||
if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo:
|
||||
cmd.append('-fa')
|
||||
cmd += ('-fa', 'on')
|
||||
try:
|
||||
subprocess.check_call(cmd)
|
||||
except subprocess.CalledProcessError:
|
||||
|
||||
157
scripts/hip/gcn-cdna-vgpr-check.py
Normal file
157
scripts/hip/gcn-cdna-vgpr-check.py
Normal file
@@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def parse_log_file(filepath):
|
||||
"""Parse log file and extract function VGPR usage."""
|
||||
import re
|
||||
|
||||
functions = defaultdict(lambda: {'vgprs': 0, 'spill': 0, 'location': ''})
|
||||
|
||||
try:
|
||||
with open(filepath, 'r') as f:
|
||||
content = f.read()
|
||||
# Find all function entries with VGPR usage including location
|
||||
pattern = r'([^:]+:\d+):.*?Function Name: (\S+).*?VGPRs: (\d+).*?VGPRs Spill: (\d+)'
|
||||
matches = re.findall(pattern, content, re.DOTALL)
|
||||
|
||||
for location, func_name, vgprs, spill in matches:
|
||||
functions[func_name]['vgprs'] = int(vgprs)
|
||||
functions[func_name]['spill'] = int(spill)
|
||||
# Extract just the filename and line number
|
||||
parts = location.split('/')
|
||||
if len(parts) > 0:
|
||||
short_location = parts[-1] # Get last part (filename)
|
||||
# Check if there's a line number after filename
|
||||
if ':' in short_location:
|
||||
functions[func_name]['location'] = short_location
|
||||
else:
|
||||
functions[func_name]['location'] = location
|
||||
else:
|
||||
functions[func_name]['location'] = location
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File {filepath} not found", file=sys.stderr) # noqa: NP100
|
||||
sys.exit(1)
|
||||
|
||||
return functions
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: ./vgpr_check.py <log_file>", file=sys.stderr) # noqa: NP100
|
||||
sys.exit(1)
|
||||
|
||||
log_file = sys.argv[1]
|
||||
ignored = {
|
||||
'_ZL21gated_linear_attn_f32ILi128EEviiiifPKfS1_S1_S1_S1_Pf',
|
||||
'_ZL18flash_attn_ext_f16ILi64ELi64ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi80ELi80ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi96ELi96ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi64ELi64ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL13rwkv_wkv7_f32ILi128EEviiiiPKfS1_S1_S1_S1_S1_S1_Pf',
|
||||
'_ZL18flash_attn_ext_f16ILi80ELi80ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi112ELi112ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi80ELi80ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi96ELi96ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi2ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi96ELi96ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi112ELi112ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi112ELi112ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi1ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi80ELi80ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi96ELi96ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi112ELi112ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi2ELi8ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi112ELi112ELi16ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi4ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi112ELi112ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi4ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi80ELi80ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi96ELi96ELi64ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi112ELi112ELi64ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi64ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi64ELi1ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi64ELi64ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi80ELi80ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi96ELi96ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi112ELi112ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi80ELi80ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi4ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi96ELi96ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi112ELi112ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi2ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi112ELi112ELi8ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi8ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL24mul_mat_q_stream_k_fixupIL9ggml_type22ELi8ELb1EEvPKiS2_PfPKfiiimimimi',
|
||||
'_ZL9mul_mat_qIL9ggml_type3ELi32ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type3ELi48ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type20ELi32ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type17ELi64ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL18flash_attn_ext_f16ILi80ELi80ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL15flash_attn_tileILi256ELi256ELi32ELi1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL9mul_mat_qIL9ggml_type19ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type17ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type22ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type19ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type19ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type7ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type3ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type3ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type7ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type7ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type11ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type11ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL24mul_mat_q_stream_k_fixupIL9ggml_type11ELi128ELb0EEvPKiS2_PfPKfiiimimimi',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL9mul_mat_qIL9ggml_type2ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL18flash_attn_ext_f16ILi112ELi112ELi32ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi112ELi112ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi1ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi2ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi8ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
}
|
||||
|
||||
functions = parse_log_file(log_file)
|
||||
found_issues = False
|
||||
|
||||
# First print all ignored functions (deduplicated)
|
||||
printed_ignored = set()
|
||||
for func_name, data in sorted(functions.items()):
|
||||
total_vgprs = int(data['vgprs']) + int(data['spill'])
|
||||
if total_vgprs > 256 and func_name in ignored and func_name not in printed_ignored:
|
||||
location = data.get('location', log_file)
|
||||
print(f"{location}: {func_name} - Total VGPRs: {total_vgprs} ({data['vgprs']} + {data['spill']}) [IGNORED]") # noqa: NP100
|
||||
printed_ignored.add(func_name)
|
||||
|
||||
# Then print new functions with issues in red
|
||||
for func_name, data in sorted(functions.items()):
|
||||
total_vgprs = int(data['vgprs']) + int(data['spill'])
|
||||
if total_vgprs > 256 and func_name not in ignored:
|
||||
status = "[IGNORED]" if func_name in ignored else ""
|
||||
location = data.get('location', log_file)
|
||||
# Print in red if not ignored
|
||||
color_code = "\033[91m" if func_name not in ignored else ""
|
||||
reset_code = "\033[0m" if func_name not in ignored else ""
|
||||
print(f"{color_code}{location}: {func_name} - Total VGPRs: {total_vgprs} ({data['vgprs']} + {data['spill']}) {status}{reset_code}") # noqa: NP100
|
||||
if func_name not in ignored:
|
||||
found_issues = True
|
||||
|
||||
sys.exit(1 if found_issues else 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -241,10 +241,10 @@ class CodeEditor(QPlainTextEdit):
|
||||
if not self.isReadOnly():
|
||||
selection = QTextEdit.ExtraSelection()
|
||||
line_color = QColorConstants.Yellow.lighter(160)
|
||||
selection.format.setBackground(line_color) # pyright: ignore[reportAttributeAccessIssue]
|
||||
selection.format.setProperty(QTextFormat.Property.FullWidthSelection, True) # pyright: ignore[reportAttributeAccessIssue]
|
||||
selection.cursor = self.textCursor() # pyright: ignore[reportAttributeAccessIssue]
|
||||
selection.cursor.clearSelection() # pyright: ignore[reportAttributeAccessIssue]
|
||||
selection.format.setBackground(line_color) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
selection.format.setProperty(QTextFormat.Property.FullWidthSelection, True) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
selection.cursor = self.textCursor() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
selection.cursor.clearSelection() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
extra_selections.append(selection)
|
||||
self.setExtraSelections(extra_selections)
|
||||
|
||||
@@ -262,8 +262,8 @@ class CodeEditor(QPlainTextEdit):
|
||||
)
|
||||
|
||||
extra = QTextEdit.ExtraSelection()
|
||||
extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue]
|
||||
extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue]
|
||||
extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
|
||||
self.setExtraSelections(self.extraSelections() + [extra])
|
||||
|
||||
@@ -274,8 +274,8 @@ class CodeEditor(QPlainTextEdit):
|
||||
cursor.select(QTextCursor.SelectionType.LineUnderCursor)
|
||||
|
||||
extra = QTextEdit.ExtraSelection()
|
||||
extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue]
|
||||
extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue]
|
||||
extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
|
||||
self.setExtraSelections(self.extraSelections() + [extra])
|
||||
|
||||
@@ -395,8 +395,8 @@ class JinjaTester(QMainWindow):
|
||||
ensure_ascii=ensure_ascii,
|
||||
)
|
||||
)
|
||||
env.globals["strftime_now"] = lambda format: datetime.now().strftime(format)
|
||||
env.globals["raise_exception"] = raise_exception
|
||||
env.globals["strftime_now"] = lambda format: datetime.now().strftime(format) # ty: ignore[invalid-assignment]
|
||||
env.globals["raise_exception"] = raise_exception # ty: ignore[invalid-assignment]
|
||||
try:
|
||||
template = env.from_string(template_str)
|
||||
output = template.render(context)
|
||||
|
||||
@@ -189,6 +189,7 @@ def benchmark(
|
||||
|
||||
data: list[dict] = []
|
||||
|
||||
assert isinstance(prompts, list)
|
||||
for i, p in enumerate(prompts):
|
||||
if seed_offset >= 0:
|
||||
random.seed(3 * (seed_offset + 1000 * i) + 1)
|
||||
|
||||
@@ -39,6 +39,9 @@ opmask=
|
||||
nhvx=
|
||||
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
|
||||
|
||||
hmx=
|
||||
[ "$HMX" != "" ] && hmx="GGML_HEXAGON_USE_HMX=$HMX"
|
||||
|
||||
ndev=
|
||||
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
|
||||
|
||||
@@ -51,7 +54,7 @@ adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $hmx $ndev $hb \
|
||||
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --ubatch-size 256 -fa on \
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user