mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-05-14 17:07:43 +03:00
Compare commits
93 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc97e45a14 | ||
|
|
8e52631d55 | ||
|
|
f4b5a2ee91 | ||
|
|
97f06e9eed | ||
|
|
e358d75adb | ||
|
|
cfff1fc300 | ||
|
|
3980e04d5a | ||
|
|
2496f9c149 | ||
|
|
5207d120ea | ||
|
|
a0101225bc | ||
|
|
a290ce6266 | ||
|
|
a00e47e422 | ||
|
|
750141969c | ||
|
|
a736e6c0ac | ||
|
|
e3e3f8e46a | ||
|
|
f08f20a0e3 | ||
|
|
07eaf919ed | ||
|
|
74d6248f71 | ||
|
|
2ca1161bd7 | ||
|
|
bbeb89d76c | ||
|
|
ff806a110d | ||
|
|
d5003b6e4d | ||
|
|
2635ac76e8 | ||
|
|
70a8309114 | ||
|
|
c91faf997f | ||
|
|
bf76ac77be | ||
|
|
a09a00e502 | ||
|
|
2bacb1eb77 | ||
|
|
d6e7b033a4 | ||
|
|
fa595462ca | ||
|
|
a817a22bc6 | ||
|
|
eff06702b2 | ||
|
|
e77056f9b2 | ||
|
|
935a340292 | ||
|
|
d8794eecd5 | ||
|
|
36a694c965 | ||
|
|
a4701c98f7 | ||
|
|
994118a183 | ||
|
|
c84e6d6db5 | ||
|
|
fa8feaed34 | ||
|
|
846262d787 | ||
|
|
6dcd824fce | ||
|
|
d4b0c22f9e | ||
|
|
e48034dfc9 | ||
|
|
048a490f76 | ||
|
|
db44417b02 | ||
|
|
d05fe1d7da | ||
|
|
0754b7b6fe | ||
|
|
09294365a9 | ||
|
|
63d93d1733 | ||
|
|
c5a3bc39b1 | ||
|
|
9dbb372610 | ||
|
|
228e836344 | ||
|
|
ed23489f42 | ||
|
|
457e2288c9 | ||
|
|
e8ec7ab058 | ||
|
|
1a03cf47f6 | ||
|
|
b97ebdc98f | ||
|
|
2098fd6169 | ||
|
|
ab6120cde5 | ||
|
|
c3c1505392 | ||
|
|
05e141a6b3 | ||
|
|
aab68217b7 | ||
|
|
a95a11e5b8 | ||
|
|
5cbfb18075 | ||
|
|
beb42fffa4 | ||
|
|
660b1b4bdc | ||
|
|
c20c44514a | ||
|
|
6118c043b1 | ||
|
|
5f0ab726f7 | ||
|
|
e82aaf2587 | ||
|
|
27aef3dd91 | ||
|
|
45155597aa | ||
|
|
80afa33aad | ||
|
|
b42c7fa5b8 | ||
|
|
d77599234e | ||
|
|
41a63be28e | ||
|
|
098705a29e | ||
|
|
683c5acb90 | ||
|
|
b1d5f5b449 | ||
|
|
4b221b7f1e | ||
|
|
59237bfbbc | ||
|
|
1cbc846eba | ||
|
|
3142f1dbb9 | ||
|
|
b5c4227dc6 | ||
|
|
d6a5094004 | ||
|
|
7b95ea5d11 | ||
|
|
bdc9c743a5 | ||
|
|
739393beeb | ||
|
|
fc2b0053ff | ||
|
|
7b8443ac78 | ||
|
|
5d56effdee | ||
|
|
52e5f0a5c1 |
@@ -12,6 +12,8 @@ body:
|
||||
after recreating the CMake build directory and with `-DGGML_CCACHE=OFF`.
|
||||
If the compilation succeeds with ccache disabled you should be able to permanently fix the issue
|
||||
by clearing `~/.cache/ccache` (on Linux).
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
- type: textarea
|
||||
id: commit
|
||||
attributes:
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/011-bug-results.yml
vendored
4
.github/ISSUE_TEMPLATE/011-bug-results.yml
vendored
@@ -1,5 +1,5 @@
|
||||
name: Bug (model use)
|
||||
description: Something goes wrong when using a model (in general, not specific to a single llama.cpp module).
|
||||
description: Something goes wrong when running a model (crashes, garbled outputs, etc.).
|
||||
title: "Eval bug: "
|
||||
labels: ["bug-unconfirmed", "model evaluation"]
|
||||
body:
|
||||
@@ -12,6 +12,8 @@ body:
|
||||
If you encountered the issue while using an external UI (e.g. ollama),
|
||||
please reproduce your issue using one of the examples/binaries in this repository.
|
||||
The `llama-completion` binary can be used for simple and reproducible model inference.
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
- type: textarea
|
||||
id: version
|
||||
attributes:
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/019-bug-misc.yml
vendored
2
.github/ISSUE_TEMPLATE/019-bug-misc.yml
vendored
@@ -10,6 +10,8 @@ body:
|
||||
This issue template is intended for miscellaneous bugs that don't fit into any other category.
|
||||
If you encountered the issue while using an external UI (e.g. ollama),
|
||||
please reproduce your issue using one of the examples/binaries in this repository.
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
- type: textarea
|
||||
id: version
|
||||
attributes:
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/020-enhancement.yml
vendored
2
.github/ISSUE_TEMPLATE/020-enhancement.yml
vendored
@@ -8,6 +8,8 @@ body:
|
||||
value: |
|
||||
[Please post your idea first in Discussion if there is not yet a consensus for this enhancement request. This will help to keep this issue tracker focused on enhancements that the community has agreed needs to be implemented.](https://github.com/ggml-org/llama.cpp/discussions/categories/ideas)
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
|
||||
- type: checkboxes
|
||||
id: prerequisites
|
||||
attributes:
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/030-research.yml
vendored
2
.github/ISSUE_TEMPLATE/030-research.yml
vendored
@@ -8,6 +8,8 @@ body:
|
||||
value: |
|
||||
Don't forget to check for any [duplicate research issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3A%22research+%F0%9F%94%AC%22)
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
|
||||
- type: checkboxes
|
||||
id: research-stage
|
||||
attributes:
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/040-refactor.yml
vendored
2
.github/ISSUE_TEMPLATE/040-refactor.yml
vendored
@@ -9,6 +9,8 @@ body:
|
||||
Don't forget to [check for existing refactor issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3Arefactoring) in case it's already covered.
|
||||
Also you may want to check [Pull request refactor label as well](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Aopen+is%3Apr+label%3Arefactoring) for duplicates too.
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
|
||||
- type: textarea
|
||||
id: background-description
|
||||
attributes:
|
||||
|
||||
2
.github/workflows/gguf-publish.yml
vendored
2
.github/workflows/gguf-publish.yml
vendored
@@ -29,10 +29,10 @@ jobs:
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
pip-install: poetry==2.4.0
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
cd gguf-py
|
||||
python -m pip install poetry==2.3.2
|
||||
poetry install
|
||||
|
||||
- name: Build package
|
||||
|
||||
2
.github/workflows/python-type-check.yml
vendored
2
.github/workflows/python-type-check.yml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
pip-install: -r requirements/requirements-all.txt ty==0.0.26
|
||||
pip-install: -r requirements/requirements-all.txt ty==0.0.33
|
||||
# - name: Type-check with Pyright
|
||||
# uses: jakebailey/pyright-action@v2
|
||||
# with:
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -105,6 +105,8 @@
|
||||
__pycache__/
|
||||
*/poetry.lock
|
||||
poetry.toml
|
||||
poetry.lock
|
||||
uv.lock
|
||||
|
||||
# Nix
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ General:
|
||||
- By very precise and concise when writing code, comments, explanations, etc.
|
||||
- PR and commit titles format: `<module> : <title>`. Lookup recents for examples
|
||||
- Don't try to build or run the code unless you are explicitly asked to do so
|
||||
- Use the `gh` CLI tool when querying PRs, issues, or other GitHub resources
|
||||
|
||||
Coding:
|
||||
- When in doubt, always refer to the CONTRIBUTING.md file of the project
|
||||
|
||||
@@ -76,6 +76,7 @@
|
||||
/ggml/src/ggml-vulkan/ @ggml-org/ggml-vulkan
|
||||
/ggml/src/ggml-webgpu/ @ggml-org/ggml-webgpu
|
||||
/ggml/src/ggml-zdnn/ @ggml-org/ggml-zdnn @Andreas-Krebbel @AlekseiNikiforovIBM
|
||||
/ggml/src/ggml-zendnn/ @avinashcpandey @Jiten1parmar @z-vishal
|
||||
/ggml/src/ggml.c @ggerganov
|
||||
/ggml/src/ggml.cpp @ggerganov
|
||||
/ggml/src/gguf.cpp @JohannesGaessler @Green-Sky
|
||||
|
||||
@@ -248,6 +248,8 @@ std::vector<std::string> common_arg::get_env() const {
|
||||
|
||||
// Helper function to parse tensor buffer override strings
|
||||
static void parse_tensor_buffer_overrides(const std::string & value, std::vector<llama_model_tensor_buft_override> & overrides) {
|
||||
ggml_backend_load_all();
|
||||
|
||||
std::map<std::string, ggml_backend_buffer_type_t> buft_list;
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
@@ -425,6 +427,10 @@ static bool parse_bool_value(const std::string & value) {
|
||||
}
|
||||
}
|
||||
|
||||
[[noreturn]] static void arg_removed(const std::string & msg) {
|
||||
throw std::invalid_argument("the argument has been removed. " + msg);
|
||||
}
|
||||
|
||||
//
|
||||
// CLI argument parsing functions
|
||||
//
|
||||
@@ -803,6 +809,7 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
|
||||
if (dev_names.size() == 1 && dev_names[0] == "none") {
|
||||
devices.push_back(nullptr);
|
||||
} else {
|
||||
ggml_backend_load_all();
|
||||
for (const auto & device : dev_names) {
|
||||
auto * dev = ggml_backend_dev_by_name(device.c_str());
|
||||
if (!dev || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
|
||||
@@ -820,6 +827,7 @@ static void add_rpc_devices(const std::string & servers) {
|
||||
if (rpc_servers.empty()) {
|
||||
throw std::invalid_argument("no RPC servers specified");
|
||||
}
|
||||
ggml_backend_load_all();
|
||||
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
|
||||
if (!rpc_reg) {
|
||||
throw std::invalid_argument("failed to find RPC backend");
|
||||
@@ -1016,9 +1024,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
|
||||
params.use_color = tty_can_use_colors();
|
||||
|
||||
// load dynamic backends
|
||||
ggml_backend_load_all();
|
||||
|
||||
common_params_context ctx_arg(params);
|
||||
ctx_arg.print_usage = print_usage;
|
||||
ctx_arg.ex = ex;
|
||||
@@ -2275,6 +2280,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--list-devices"},
|
||||
"print list of available devices and exit",
|
||||
[](common_params &) {
|
||||
ggml_backend_load_all();
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
@@ -2864,7 +2870,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--tools"}, "TOOL1,TOOL2,...",
|
||||
"experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)\n"
|
||||
"specify \"all\" to enable all tools\n"
|
||||
"available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff",
|
||||
"available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff, get_datetime",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.server_tools = parse_csv_row(value);
|
||||
}
|
||||
@@ -3380,7 +3386,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-draft-poll", "--poll-draft"}, "<0|1>",
|
||||
"Use polling to wait for draft model work (default: same as --poll])",
|
||||
"Use polling to wait for draft model work (default: same as --poll)",
|
||||
[](common_params & params, int value) {
|
||||
params.speculative.draft.cpuparams.poll = value;
|
||||
}
|
||||
@@ -3499,7 +3505,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_N_MIN"));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--spec--draft-p-split", "--draft-p-split"}, "P",
|
||||
{"--spec-draft-p-split", "--draft-p-split"}, "P",
|
||||
string_format("speculative decoding split probability (default: %.2f)", (double)params.speculative.draft.p_split),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.draft.p_split = std::stof(value);
|
||||
@@ -3715,35 +3721,35 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--draft", "--draft-n", "--draft-max"}, "N",
|
||||
"the argument has been removed. use --spec-draft-n-max or --spec-ngram-mod-n-max",
|
||||
[](common_params & /*params*/, int /*value*/) {
|
||||
throw std::invalid_argument("the argument has been removed. use --spec-draft-n-max or --spec-ngram-mod-n-max");
|
||||
arg_removed("use --spec-draft-n-max or --spec-ngram-mod-n-max");
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_MAX"));
|
||||
add_opt(common_arg(
|
||||
{"--draft-min", "--draft-n-min"}, "N",
|
||||
"the argument has been removed. use --spec-draft-n-min or --spec-ngram-mod-n-min",
|
||||
[](common_params & /*params*/, int /*value*/) {
|
||||
throw std::invalid_argument("the argument has been removed. use --spec-draft-n-min or --spec-ngram-mod-n-min");
|
||||
arg_removed("use --spec-draft-n-min or --spec-ngram-mod-n-min");
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_MIN"));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-n"}, "N",
|
||||
"the argument has been removed. use the respective --spec-ngram-*-size-n or --spec-ngram-mod-n-match",
|
||||
[](common_params & /*params*/, int /*value*/) {
|
||||
throw std::invalid_argument("the argument has been removed. use the respective --spec-ngram-*-size-n");
|
||||
arg_removed("use the respective --spec-ngram-*-size-n");
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-m"}, "N",
|
||||
"the argument has been removed. use the respective --spec-ngram-*-size-m",
|
||||
[](common_params & /*params*/, int /*value*/) {
|
||||
throw std::invalid_argument("the argument has been removed. use the respective --spec-ngram-*-size-m");
|
||||
arg_removed("use the respective --spec-ngram-*-size-m");
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-min-hits"}, "N",
|
||||
"the argument has been removed. use the respective --spec-ngram-*-min-hits",
|
||||
[](common_params & /*params*/, int /*value*/) {
|
||||
throw std::invalid_argument("the argument has been removed. use the respective --spec-ngram-*-min-hits");
|
||||
arg_removed("use the respective --spec-ngram-*-min-hits");
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
@@ -3794,7 +3800,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-algorithm"}, "N",
|
||||
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)", params.diffusion.algorithm),
|
||||
string_format(
|
||||
"diffusion algorithm: 0=DIFFUSION_ALGORITHM_ORIGIN, 1=DIFFUSION_ALGORITHM_ENTROPY_BASED, "
|
||||
"2=DIFFUSION_ALGORITHM_MARGIN_BASED, 3=DIFFUSION_ALGORITHM_RANDOM, "
|
||||
"4=DIFFUSION_ALGORITHM_CONFIDENCE_BASED (default: %d)", params.diffusion.algorithm),
|
||||
[](common_params & params, int value) { params.diffusion.algorithm = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
|
||||
@@ -136,10 +136,10 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
|
||||
if (!end.empty()) {
|
||||
if (!start.empty()) {
|
||||
// Standard tag-based: optional(<think>reasoning</think>)
|
||||
return p.optional(start + p.reasoning(p.until(end)) + end + p.space());
|
||||
return p.optional(p.optspace(start) + p.reasoning(p.until(trim_whitespace(end))) + p.optspace(end));
|
||||
}
|
||||
// Delimiter-style (empty start)
|
||||
return p.optional(p.reasoning(p.until(end)) + end + p.space());
|
||||
return p.optional(p.reasoning(p.until(trim_whitespace(end))) + p.optspace(end));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,7 +186,6 @@ common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const
|
||||
common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
const auto & inputs = ctx.inputs;
|
||||
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
// Build effective field names with dot notation if function_field is set
|
||||
std::string name_field = format.name_field;
|
||||
@@ -225,8 +224,7 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont
|
||||
tool_start = format.per_call_start;
|
||||
}
|
||||
|
||||
return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(p.until(tool_start)))) + tools_parser +
|
||||
p.end();
|
||||
return ctx.reasoning_parser + p.optional(p.content(p.until(tool_start))) + tools_parser + p.end();
|
||||
}
|
||||
|
||||
common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p, const std::string & name,
|
||||
@@ -270,7 +268,6 @@ common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p,
|
||||
common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
const auto & inputs = ctx.inputs;
|
||||
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
common_peg_parser tool_choice = p.choice();
|
||||
|
||||
@@ -336,14 +333,12 @@ common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context
|
||||
|
||||
std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
|
||||
auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker);
|
||||
return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls +
|
||||
p.end();
|
||||
return ctx.reasoning_parser + p.optional(p.content(content_before_tools)) + tool_calls + p.end();
|
||||
}
|
||||
|
||||
common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
const auto & inputs = ctx.inputs;
|
||||
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
auto until_suffix = p.rule("until-suffix", p.until(arguments.value_suffix));
|
||||
|
||||
@@ -471,8 +466,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
|
||||
std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
|
||||
auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker);
|
||||
return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls +
|
||||
p.end();
|
||||
return ctx.reasoning_parser + p.optional(p.content(content_before_tools)) + tool_calls + p.end();
|
||||
}
|
||||
|
||||
} // namespace autoparser
|
||||
|
||||
@@ -342,7 +342,7 @@ void analyze_reasoning::compare_thinking_enabled() {
|
||||
if (left_trimmed.empty() && !diff.right.empty()) {
|
||||
if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) {
|
||||
if (start.empty()) {
|
||||
start = trim_leading_whitespace(diff.right);
|
||||
start = diff.right;
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
@@ -353,7 +353,7 @@ void analyze_reasoning::compare_thinking_enabled() {
|
||||
if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) {
|
||||
start = seg[seg.size() - 2].value;
|
||||
}
|
||||
end = trim_trailing_whitespace(diff.left);
|
||||
end = diff.left;
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
@@ -445,14 +445,14 @@ void analyze_reasoning::compare_reasoning_scope() {
|
||||
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
|
||||
if (result.result.success()) {
|
||||
start = result.tags["pre"];
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
end = result.tags["post"];
|
||||
} else {
|
||||
auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())));
|
||||
});
|
||||
result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B);
|
||||
if (result.result.success()) {
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
end = result.tags["post"];
|
||||
} else {
|
||||
LOG_DBG(ANSI_ORANGE "%s: Unable to extract reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
|
||||
mode = reasoning_mode::NONE;
|
||||
|
||||
@@ -816,6 +816,32 @@ common_peg_parser common_chat_peg_builder::prefix(const std::string & s, const s
|
||||
return literal(s.substr(0, s.rfind(delimiter)));
|
||||
}
|
||||
|
||||
common_peg_parser common_chat_peg_builder::optspace(const std::string & tag) {
|
||||
auto parser = eps();
|
||||
size_t end_of_prefix_space = tag.size();
|
||||
size_t start_of_suffix_space = tag.size();
|
||||
for (size_t i = 0; i < tag.size(); i++) {
|
||||
if (!std::isspace(tag[i])) {
|
||||
end_of_prefix_space = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (size_t i = tag.size(); i > 0; i--) {
|
||||
if (!std::isspace(tag[i - 1])) {
|
||||
start_of_suffix_space = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < end_of_prefix_space; i++) {
|
||||
parser += optional(literal(std::string(1, tag[i])));
|
||||
}
|
||||
parser += literal(tag.substr(end_of_prefix_space, start_of_suffix_space - end_of_prefix_space));
|
||||
for (size_t i = start_of_suffix_space; i < tag.size(); i++) {
|
||||
parser += optional(literal(std::string(1, tag[i])));
|
||||
}
|
||||
return parser;
|
||||
}
|
||||
|
||||
common_peg_parser common_chat_peg_builder::standard_json_tools(
|
||||
const std::string & section_start,
|
||||
const std::string & section_end,
|
||||
|
||||
@@ -96,6 +96,9 @@ class common_chat_peg_builder : public common_peg_parser_builder {
|
||||
// Return a parser that parses the prefix of a string, up to a given delimiter.
|
||||
common_peg_parser prefix(const std::string & s, const std::string & delimiter = {});
|
||||
|
||||
// Return a parser that parses all elements of tag, but leading and trailing spaces are optional
|
||||
common_peg_parser optspace(const std::string & tag);
|
||||
|
||||
// Legacy-compatible helper for building standard JSON tool calls
|
||||
// Used by tests and manual parsers
|
||||
// name_key/args_key: JSON key names for function name and arguments
|
||||
|
||||
@@ -2116,22 +2116,38 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::string common_chat_templates_generation_prompt(const common_chat_template & tmpl, const autoparser::generation_params & inputs) {
|
||||
autoparser::generation_params params = inputs;
|
||||
params.add_generation_prompt = false;
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
|
||||
size_t prefix_len = 0;
|
||||
size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size());
|
||||
while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) {
|
||||
prefix_len++;
|
||||
}
|
||||
return gen_prompt.substr(prefix_len);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs) {
|
||||
autoparser::generation_params params;
|
||||
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
|
||||
const auto & tmpl =
|
||||
params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
params.enable_thinking = inputs.enable_thinking;
|
||||
params.grammar = inputs.grammar;
|
||||
params.now = inputs.now;
|
||||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
params.enable_thinking = inputs.enable_thinking;
|
||||
params.grammar = inputs.grammar;
|
||||
params.now = inputs.now;
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
|
||||
if (src.find("<|channel|>") == std::string::npos) {
|
||||
// map developer to system for all models except for GPT-OSS
|
||||
@@ -2153,14 +2169,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
workaround::func_args_not_string(params.messages);
|
||||
}
|
||||
|
||||
params.add_generation_prompt = false;
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
|
||||
params.generation_prompt = diff.right + diff.suffix;
|
||||
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
params.generation_prompt = common_chat_templates_generation_prompt(tmpl, params);
|
||||
|
||||
params.extra_context = common_chat_extra_context();
|
||||
for (auto el : inputs.chat_template_kwargs) {
|
||||
@@ -2212,8 +2221,8 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
if (auto_params.supports_thinking) {
|
||||
auto_params.thinking_start_tag = autoparser.reasoning.start;
|
||||
auto_params.thinking_end_tag = autoparser.reasoning.end;
|
||||
auto_params.thinking_start_tag = trim_whitespace(autoparser.reasoning.start);
|
||||
auto_params.thinking_end_tag = trim_whitespace(autoparser.reasoning.end);
|
||||
}
|
||||
auto_params.generation_prompt = params.generation_prompt;
|
||||
common_peg_arena arena;
|
||||
|
||||
@@ -109,16 +109,24 @@ static std::vector<llama_device_memory_data> common_get_device_memory_data(
|
||||
ret.back().total = total;
|
||||
}
|
||||
for (size_t i = 0; i < nd; i++) {
|
||||
ggml_backend_dev_t dev = llama_model_get_device(model, i);
|
||||
|
||||
size_t free;
|
||||
size_t total;
|
||||
ggml_backend_dev_memory(llama_model_get_device(model, i), &free, &total);
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
|
||||
// devices can return 0 bytes for free and total memory if they do not
|
||||
// have any to report. in this case, we will use the host memory as a fallback
|
||||
// fixes: https://github.com/ggml-org/llama.cpp/issues/18577
|
||||
// Some non-GPU accelerator backends, such as BLAS, report 0/0 and rely on
|
||||
// the host-memory fallback. For GPU-like backends, keep 0/0 so --fit does
|
||||
// not assign anything to a device with an unknown memory budget.
|
||||
if (free == 0 && total == 0) {
|
||||
free = ret.back().free;
|
||||
total = ret.back().total;
|
||||
const enum ggml_backend_dev_type type = ggml_backend_dev_type(dev);
|
||||
if (type == GGML_BACKEND_DEVICE_TYPE_GPU || type == GGML_BACKEND_DEVICE_TYPE_IGPU) {
|
||||
LOG_WRN("%s: device %s did not report memory; --fit will not use it\n",
|
||||
__func__, ggml_backend_dev_name(dev));
|
||||
} else {
|
||||
free = ret.back().free;
|
||||
total = ret.back().total;
|
||||
}
|
||||
}
|
||||
ret[i].free = free;
|
||||
ret[i].total = total;
|
||||
|
||||
@@ -57,7 +57,7 @@ static fs::path get_cache_directory() {
|
||||
#ifndef _WIN32
|
||||
const struct passwd * pw = getpwuid(getuid());
|
||||
|
||||
if (pw->pw_dir && *pw->pw_dir) {
|
||||
if (pw && pw->pw_dir && *pw->pw_dir) {
|
||||
return fs::path(pw->pw_dir) / ".cache" / "huggingface" / "hub";
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -49,7 +49,7 @@ enum common_log_col : int {
|
||||
};
|
||||
|
||||
// disable colors by default
|
||||
static std::vector<const char *> g_col = {
|
||||
static const char* g_col[] = {
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
@@ -247,7 +247,6 @@ public:
|
||||
|
||||
entries = std::move(new_entries);
|
||||
}
|
||||
|
||||
cv.notify_one();
|
||||
}
|
||||
|
||||
@@ -265,7 +264,6 @@ public:
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
cv.wait(lock, [this]() { return head != tail; });
|
||||
|
||||
cur = entries[head];
|
||||
|
||||
head = (head + 1) % entries.size();
|
||||
@@ -301,7 +299,6 @@ public:
|
||||
|
||||
tail = (tail + 1) % entries.size();
|
||||
}
|
||||
|
||||
cv.notify_one();
|
||||
}
|
||||
|
||||
@@ -338,7 +335,7 @@ public:
|
||||
g_col[COMMON_LOG_COL_CYAN] = LOG_COL_CYAN;
|
||||
g_col[COMMON_LOG_COL_WHITE] = LOG_COL_WHITE;
|
||||
} else {
|
||||
for (size_t i = 0; i < g_col.size(); i++) {
|
||||
for (size_t i = 0; i < std::size(g_col); i++) {
|
||||
g_col[i] = "";
|
||||
}
|
||||
}
|
||||
@@ -368,14 +365,20 @@ struct common_log * common_log_init() {
|
||||
}
|
||||
|
||||
struct common_log * common_log_main() {
|
||||
static struct common_log log;
|
||||
// We intentionally leak (i.e. do not delete) the logger singleton because
|
||||
// common_log destructor called at DLL teardown phase will cause hanging on Windows.
|
||||
// OS will release resources anyway so it should not be a significant issue,
|
||||
// though this design may cause logs to be lost if not flushed before the program exits.
|
||||
// Refer to https://github.com/ggml-org/llama.cpp/issues/22142 for details.
|
||||
static struct common_log * log;
|
||||
static std::once_flag init_flag;
|
||||
std::call_once(init_flag, [&]() {
|
||||
log = new common_log;
|
||||
// Set default to auto-detect colors
|
||||
log.set_colors(tty_can_use_colors());
|
||||
log->set_colors(tty_can_use_colors());
|
||||
});
|
||||
|
||||
return &log;
|
||||
return log;
|
||||
}
|
||||
|
||||
void common_log_pause(struct common_log * log) {
|
||||
|
||||
@@ -49,7 +49,11 @@ void common_log_default_callback(enum ggml_log_level level, const char * text, v
|
||||
struct common_log;
|
||||
|
||||
struct common_log * common_log_init();
|
||||
struct common_log * common_log_main(); // singleton, automatically destroys itself on exit
|
||||
|
||||
// Singleton, intentionally leaked to avoid Windows teardown hangs.
|
||||
// Call common_log_flush() before exit if you want to ensure all logs are flushed.
|
||||
struct common_log * common_log_main();
|
||||
|
||||
void common_log_pause (struct common_log * log); // pause the worker thread, not thread-safe
|
||||
void common_log_resume(struct common_log * log); // resume the worker thread, not thread-safe
|
||||
void common_log_free (struct common_log * log);
|
||||
|
||||
@@ -122,6 +122,20 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
}
|
||||
break;
|
||||
case REASONING_BUDGET_DONE:
|
||||
// Re-arm on a new start tag: some models emit multiple <think> blocks
|
||||
// per response, and each should get a fresh budget window.
|
||||
if (ctx->start_matcher.advance(token)) {
|
||||
ctx->state = REASONING_BUDGET_COUNTING;
|
||||
ctx->remaining = ctx->budget;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: re-activated on new start tag, budget=%d tokens\n", ctx->budget);
|
||||
|
||||
if (ctx->remaining <= 0) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -144,6 +158,8 @@ static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_tok
|
||||
for (size_t i = 0; i < cur_p->size; i++) {
|
||||
if (cur_p->data[i].id != forced) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
} else {
|
||||
cur_p->data[i].logit = +INFINITY; // force the token
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -218,34 +234,6 @@ static struct llama_sampler * common_reasoning_budget_init_state(
|
||||
);
|
||||
}
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
const std::vector<llama_token> & prefill_tokens) {
|
||||
// Determine initial state from prefill: COUNTING if the prefill begins with
|
||||
// the start sequence but does not also contain the end sequence after it.
|
||||
common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE;
|
||||
if (!prefill_tokens.empty() && !start_tokens.empty() &&
|
||||
prefill_tokens.size() >= start_tokens.size() &&
|
||||
std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) {
|
||||
initial_state = REASONING_BUDGET_COUNTING;
|
||||
// If the end sequence also follows the start in the prefill, reasoning
|
||||
// was opened and immediately closed — stay IDLE.
|
||||
if (!end_tokens.empty() &&
|
||||
prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) {
|
||||
auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size();
|
||||
if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() &&
|
||||
std::equal(end_tokens.begin(), end_tokens.end(), end_start)) {
|
||||
initial_state = REASONING_BUDGET_IDLE;
|
||||
}
|
||||
}
|
||||
}
|
||||
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
|
||||
}
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
|
||||
@@ -29,10 +29,7 @@ enum common_reasoning_budget_state {
|
||||
// end_tokens - token sequence for natural deactivation
|
||||
// forced_tokens - token sequence forced when budget expires
|
||||
// budget - max tokens allowed in the reasoning block
|
||||
// prefill_tokens - tokens already present in the prompt (generation prompt);
|
||||
// used to determine the initial state: COUNTING if they begin
|
||||
// with start_tokens (but don't also end with end_tokens),
|
||||
// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
|
||||
// initial_state - initial state
|
||||
//
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
@@ -40,16 +37,6 @@ struct llama_sampler * common_reasoning_budget_init(
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
const std::vector<llama_token> & prefill_tokens = {});
|
||||
|
||||
// Variant that takes an explicit initial state (used by tests and clone).
|
||||
// COUNTING with budget <= 0 is promoted to FORCING.
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state);
|
||||
common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE);
|
||||
|
||||
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);
|
||||
|
||||
@@ -260,32 +260,35 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
}
|
||||
}
|
||||
|
||||
// Compute prefill tokens from the generation prompt
|
||||
std::vector<llama_token> prefill_tokens;
|
||||
if (!params.generation_prompt.empty()) {
|
||||
GGML_ASSERT(vocab != nullptr);
|
||||
auto tokens = common_tokenize(vocab, params.generation_prompt, false, true);
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
std::string piece = common_token_to_piece(vocab, tokens[i], true);
|
||||
if (i == 0 && std::isspace(piece[0]) && !std::isspace(params.generation_prompt[0])) {
|
||||
// Some tokenizers will add a space before the first special token, need to exclude
|
||||
continue;
|
||||
}
|
||||
LOG_DBG("%s: prefill token: %d = %s\n", __func__, tokens[i], piece.c_str());
|
||||
prefill_tokens.push_back(tokens[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Feed generation prompt tokens to the grammar sampler so it advances past
|
||||
// tokens the template already placed in the prompt.
|
||||
// Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled.
|
||||
std::vector<llama_token> prefill_tokens;
|
||||
if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) {
|
||||
GGML_ASSERT(vocab != nullptr);
|
||||
prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true);
|
||||
if (!prefill_tokens.empty()) {
|
||||
std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true);
|
||||
if (std::isspace(first_token[0]) && !std::isspace(params.generation_prompt[0])) {
|
||||
// Some tokenizers will add a space before the first special token, need to remove
|
||||
prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end());
|
||||
}
|
||||
}
|
||||
|
||||
if (grmr && !params.grammar_lazy) {
|
||||
try {
|
||||
for (const auto & token : prefill_tokens) {
|
||||
llama_sampler_accept(grmr, token);
|
||||
LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token);
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
|
||||
common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
|
||||
throw e;
|
||||
if (grmr && !params.grammar_lazy && common_grammar_needs_prefill(params.grammar)) {
|
||||
try {
|
||||
for (const auto & token : prefill_tokens) {
|
||||
llama_sampler_accept(grmr, token);
|
||||
LOG_DBG("%s: grammar accepted prefill token (%d)\n", __func__, token);
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
|
||||
common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -296,8 +299,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
params.reasoning_budget_start,
|
||||
params.reasoning_budget_end,
|
||||
params.reasoning_budget_forced,
|
||||
params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
|
||||
prefill_tokens);
|
||||
params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens);
|
||||
|
||||
for (const auto & token : prefill_tokens) {
|
||||
llama_sampler_accept(rbudget, token);
|
||||
LOG_DBG("%s: reasoning-budget accepted prefill token (%d)\n", __func__, token);
|
||||
}
|
||||
}
|
||||
|
||||
if (params.has_logit_bias()) {
|
||||
@@ -431,7 +438,7 @@ static bool grammar_should_apply(struct common_sampler * gsmpl) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool is_generated) {
|
||||
if (!gsmpl) {
|
||||
return;
|
||||
}
|
||||
@@ -439,9 +446,11 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
// grammar_should_apply() checks the reasoning budget state, so calculate this before we accept
|
||||
accept_grammar = accept_grammar && grammar_should_apply(gsmpl);
|
||||
const auto accept_grammar = is_generated && grammar_should_apply(gsmpl);
|
||||
|
||||
llama_sampler_accept(gsmpl->rbudget, token);
|
||||
if (gsmpl->rbudget && is_generated) {
|
||||
llama_sampler_accept(gsmpl->rbudget, token);
|
||||
}
|
||||
|
||||
if (gsmpl->grmr && accept_grammar) {
|
||||
llama_sampler_accept(gsmpl->grmr, token);
|
||||
|
||||
@@ -41,8 +41,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
|
||||
void common_sampler_free(struct common_sampler * gsmpl);
|
||||
|
||||
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
|
||||
// if is_generated is true, the token is accepted by the sampling chain, the reasoning budget sampler, and the grammar sampler
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool is_generated);
|
||||
void common_sampler_reset (struct common_sampler * gsmpl);
|
||||
struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
|
||||
|
||||
|
||||
@@ -167,8 +167,6 @@ struct common_speculative_checkpoint {
|
||||
size_t size() const {
|
||||
return data.size();
|
||||
}
|
||||
|
||||
size_t ckpt_size = 0;
|
||||
};
|
||||
|
||||
struct common_speculative_state_draft : public common_speculative_state {
|
||||
@@ -176,7 +174,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
llama_context * ctx_dft;
|
||||
|
||||
bool use_ckpt = false;
|
||||
struct common_speculative_checkpoint ckpt;
|
||||
common_speculative_checkpoint ckpt;
|
||||
|
||||
common_sampler * smpl;
|
||||
|
||||
@@ -249,29 +247,19 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
if (use_ckpt && ckpt.size() > 0) {
|
||||
// delete checkpoint
|
||||
LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
|
||||
__func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
|
||||
ckpt.pos_min = 0;
|
||||
ckpt.pos_max = 0;
|
||||
ckpt.n_tokens = 0;
|
||||
ckpt.ckpt_size = 0;
|
||||
ckpt.data.clear();
|
||||
}
|
||||
void begin(const llama_tokens & /*prompt*/) override {
|
||||
}
|
||||
|
||||
size_t draft_create_checkpoint(int n_tokens_prompt, int n_tokens_batch) {
|
||||
size_t create_checkpoint(int n_tokens_prompt) {
|
||||
int slot_id = 0;
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
|
||||
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
|
||||
ckpt.n_tokens = n_tokens_prompt - n_tokens_batch;
|
||||
ckpt.n_tokens = n_tokens_prompt;
|
||||
ckpt.data.resize(checkpoint_size);
|
||||
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (n != checkpoint_size) {
|
||||
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
|
||||
}
|
||||
@@ -281,13 +269,13 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
return n;
|
||||
}
|
||||
|
||||
size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) {
|
||||
size_t restore_checkpoint() {
|
||||
int slot_id = 0;
|
||||
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
if (n != ckpt_size_part_expected) {
|
||||
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
|
||||
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (n != ckpt.size()) {
|
||||
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu",
|
||||
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
||||
}
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);
|
||||
|
||||
@@ -346,13 +334,18 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
|
||||
const int i_start = std::max<int>(0, (int) prompt_cur.size() - n_ctx);
|
||||
|
||||
if (use_ckpt && i_start > 0) {
|
||||
LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
// reuse as much as possible from the old draft context
|
||||
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
|
||||
for (int i = 0; i < (int) prompt_dft.size(); ++i) {
|
||||
int cur = 0;
|
||||
while (i_start + cur < (int) prompt_cur.size() &&
|
||||
i + cur < (int) prompt_dft.size() &&
|
||||
prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
|
||||
i + cur < (int) prompt_dft.size() &&
|
||||
prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
|
||||
cur++;
|
||||
}
|
||||
|
||||
@@ -360,21 +353,26 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
reuse_i = i;
|
||||
reuse_n = cur;
|
||||
}
|
||||
|
||||
if (use_ckpt) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
|
||||
__func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
|
||||
if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) {
|
||||
LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
|
||||
__func__, reuse_i, reuse_n);
|
||||
if (use_ckpt && ckpt.n_tokens > reuse_n) {
|
||||
LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens);
|
||||
|
||||
reuse_i = 0;
|
||||
reuse_n = 0;
|
||||
|
||||
ckpt = {};
|
||||
}
|
||||
|
||||
result.clear();
|
||||
result.reserve(sparams.n_max);
|
||||
|
||||
bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
|
||||
if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
|
||||
llama_memory_clear(mem_dft, false);
|
||||
prompt_dft.clear();
|
||||
@@ -393,50 +391,38 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
return;
|
||||
}
|
||||
|
||||
bool do_restore = false;
|
||||
if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) {
|
||||
// This can happen after a partial acceptance (speculative decoding with checkpoints)
|
||||
LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n",
|
||||
__func__, prompt_dft.size(), prompt_cur.size());
|
||||
prompt_dft.resize(prompt_cur.size());
|
||||
do_restore = true;
|
||||
}
|
||||
|
||||
if (reuse_i > 0) {
|
||||
GGML_ASSERT(!use_ckpt);
|
||||
|
||||
bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
|
||||
if (!is_removed) {
|
||||
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
|
||||
return;
|
||||
}
|
||||
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
|
||||
|
||||
prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
|
||||
}
|
||||
|
||||
if (reuse_n < (int) prompt_dft.size() || do_restore) {
|
||||
if (reuse_n < (int) prompt_dft.size()) {
|
||||
if (use_ckpt) {
|
||||
if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
|
||||
LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
|
||||
__func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
|
||||
if (ckpt.n_tokens > 0) {
|
||||
LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
|
||||
restore_checkpoint();
|
||||
reuse_n = ckpt.n_tokens;
|
||||
prompt_dft.resize(reuse_n);
|
||||
}
|
||||
draft_restore_checkpoint(ckpt.ckpt_size);
|
||||
reuse_n = ckpt.n_tokens;
|
||||
prompt_dft.resize(reuse_n);
|
||||
needs_ckpt = false;
|
||||
} else {
|
||||
bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
|
||||
const bool is_removed = llama_memory_seq_rm(mem_dft, 0, reuse_n, -1);
|
||||
if (!is_removed) {
|
||||
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n",
|
||||
__func__, reuse_n, prompt_dft.size());
|
||||
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
|
||||
return;
|
||||
}
|
||||
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (needs_ckpt) {
|
||||
ckpt.ckpt_size = draft_create_checkpoint(prompt_dft.size(), batch.n_tokens);
|
||||
}
|
||||
|
||||
// prepare a batch to evaluate any new tokens in the prompt
|
||||
common_batch_clear(batch);
|
||||
|
||||
@@ -450,12 +436,17 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
// we should rarely end-up here during normal decoding
|
||||
if (batch.n_tokens > 0) {
|
||||
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
|
||||
LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens);
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0 && ret != 1) {
|
||||
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n",
|
||||
__func__, ret, prompt_cur.size());
|
||||
}
|
||||
|
||||
if (use_ckpt) {
|
||||
create_checkpoint(prompt_dft.size());
|
||||
}
|
||||
}
|
||||
|
||||
const llama_pos n_past = prompt_dft.size();
|
||||
@@ -467,7 +458,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
|
||||
prompt_dft.push_back(id_last);
|
||||
|
||||
LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
|
||||
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0 && ret != 1) {
|
||||
@@ -495,14 +486,14 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
|
||||
common_sampler_accept(smpl, id, true);
|
||||
|
||||
result.push_back(id);
|
||||
|
||||
if (sparams.n_max <= (int) result.size()) {
|
||||
// only collect very high-confidence draft tokens
|
||||
if (cur_p->data[0].p < sparams.p_min) {
|
||||
break;
|
||||
}
|
||||
|
||||
// only collect very high-confidence draft tokens
|
||||
if (cur_p->data[0].p < sparams.p_min) {
|
||||
result.push_back(id);
|
||||
|
||||
if (sparams.n_max <= (int) result.size()) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -784,17 +775,15 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
if (verbose) {
|
||||
LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last);
|
||||
}
|
||||
|
||||
// compute acceptance fraction if we have a recorded draft length
|
||||
if (n_draft_last > 0) {
|
||||
const double f_acc = (double)n_accepted / (double)n_draft_last;
|
||||
if (f_acc < 0.5) {
|
||||
n_low++;
|
||||
if (n_low >= 3) {
|
||||
LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
|
||||
if (verbose) {
|
||||
LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
|
||||
}
|
||||
|
||||
mod.reset();
|
||||
n_low = 0;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -175,6 +175,7 @@ pre_computed_hashes = [
|
||||
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
|
||||
{"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
|
||||
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openbmb/MiniCPM-V-4_6", "chkhsh": "1444df51289cfa8063b96f0e62b1125440111bc79a52003ea14b6eac7016fd5f"},
|
||||
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
|
||||
# jina-v2-de variants
|
||||
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
|
||||
|
||||
49
docs/multimodal/minicpmv4.6.md
Normal file
49
docs/multimodal/minicpmv4.6.md
Normal file
@@ -0,0 +1,49 @@
|
||||
## MiniCPM-V 4.6
|
||||
|
||||
### Prepare models and code
|
||||
|
||||
Download [MiniCPM-V-4_6](https://huggingface.co/openbmb/MiniCPM-V-4_6) PyTorch model from huggingface to "MiniCPM-V-4_6" folder.
|
||||
|
||||
The model must be the standard `transformers` v5.7.0+ checkpoint (no `trust_remote_code`); the architecture in `config.json` is `MiniCPMV4_6ForConditionalGeneration` with a `qwen3_5_text` text model and a SigLIP-based vision tower plus a window-attention `vit_merger`.
|
||||
|
||||
### Build llama.cpp
|
||||
|
||||
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md)
|
||||
|
||||
Clone llama.cpp:
|
||||
```bash
|
||||
git clone https://github.com/ggml-org/llama.cpp
|
||||
cd llama.cpp
|
||||
```
|
||||
|
||||
Build llama.cpp using `CMake`:
|
||||
```bash
|
||||
cmake -B build
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
|
||||
### Usage of MiniCPM-V 4.6
|
||||
|
||||
Unlike older MiniCPM-V variants, MiniCPM-V 4.6 is converted directly through `convert_hf_to_gguf.py`. The same script is invoked twice on the original Hugging Face directory: once to produce the language-model GGUF and once with `--mmproj` to produce the multimodal projector GGUF.
|
||||
|
||||
```bash
|
||||
# language model
|
||||
python ./convert_hf_to_gguf.py ../MiniCPM-V-4_6 --outfile ../MiniCPM-V-4_6/ggml-model-f16.gguf
|
||||
|
||||
# multimodal projector (vision tower + window-attention vit_merger + DownsampleMLP merger)
|
||||
python ./convert_hf_to_gguf.py ../MiniCPM-V-4_6 --mmproj --outfile ../MiniCPM-V-4_6/mmproj-model-f16.gguf
|
||||
|
||||
# optional: quantize to Q4_K_M
|
||||
./build/bin/llama-quantize ../MiniCPM-V-4_6/ggml-model-f16.gguf ../MiniCPM-V-4_6/ggml-model-Q4_K_M.gguf Q4_K_M
|
||||
```
|
||||
|
||||
|
||||
Inference on Linux or Mac
|
||||
```bash
|
||||
# run in single-turn mode
|
||||
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_6/ggml-model-f16.gguf --mmproj ../MiniCPM-V-4_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
|
||||
# run in conversation mode
|
||||
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_6/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-4_6/mmproj-model-f16.gguf
|
||||
```
|
||||
@@ -33,18 +33,18 @@ An example to use this approach can be the rewriting of source code by a LLM.
|
||||
This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead.
|
||||
|
||||
```
|
||||
llama-server [...] --spec-type ngram-simple --draft-max 64
|
||||
llama-server [...] --spec-type ngram-simple --spec-draft-n-max 64
|
||||
```
|
||||
|
||||
#### n-gram Map Key (`ngram-map-k`)
|
||||
|
||||
This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-min-hits`, default is 1) before generating drafts.
|
||||
This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-map-k-min-hits`, default is 1) before generating drafts.
|
||||
|
||||
The number of accepted tokens is stored for each used n-gram.
|
||||
|
||||
**Example:**
|
||||
```
|
||||
llama-server [...] --spec-type ngram-map-k --draft-max 64
|
||||
llama-server [...] --spec-type ngram-map-k --spec-draft-n-max 64
|
||||
```
|
||||
|
||||
#### n-gram Map Key-4-Values (`ngram-map-k4v`)
|
||||
@@ -55,7 +55,7 @@ The number of accepted tokens is stored for each used n-gram.
|
||||
|
||||
**Example:** Server options to be used if there are a lot of longer repetitions.
|
||||
```
|
||||
llama-server [...] --spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2 --draft-max 64
|
||||
llama-server [...] --spec-type ngram-map-k4v --spec-ngram-map-k4v-size-n 8 --spec-ngram-map-k4v-size-m 8 --spec-ngram-map-k4v-min-hits 2 --spec-draft-n-max 64
|
||||
```
|
||||
|
||||
### n-gram Mod (`ngram-mod`)
|
||||
@@ -80,9 +80,9 @@ Currently, a single hash pool is shared across all server slots, so different re
|
||||
# notes:
|
||||
# - small `n` are not recommended
|
||||
# - MoEs require long drafts
|
||||
# - dense models: can reduce `--draft-min` and `--draft-max`
|
||||
# - dense models: can reduce `--spec-ngram-mod-n-min` and `--spec-ngram-mod-n-max`
|
||||
|
||||
llama-server ... --spec-type ngram-mod --spec-ngram-size-n 24 --draft-min 48 --draft-max 64
|
||||
llama-server ... --spec-type ngram-mod --spec-ngram-mod-n-match 24 --spec-ngram-mod-n-min 48 --spec-ngram-mod-n-max 64
|
||||
```
|
||||
|
||||
Applications:
|
||||
@@ -105,21 +105,90 @@ Example Video:
|
||||
|
||||
If a draft model is combined with a draftless decoding the draftless decoding has higher precedence.
|
||||
|
||||
### General Speculative Parameters
|
||||
|
||||
```
|
||||
--draft, --draft-n, --draft-max N number of tokens to draft for speculative decoding (default: 16)
|
||||
(env: LLAMA_ARG_DRAFT_MAX)
|
||||
--draft-min, --draft-n-min N minimum number of draft tokens to use for speculative decoding
|
||||
(default: 0)
|
||||
(env: LLAMA_ARG_DRAFT_MIN)
|
||||
[...]
|
||||
--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
|
||||
type of speculative decoding to use when no draft model is provided
|
||||
(default: none)
|
||||
--spec-ngram-size-n N ngram size N for ngram-simple/ngram-map speculative decoding, length
|
||||
of lookup n-gram (default: 12)
|
||||
--spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length
|
||||
of draft m-gram (default: 48)
|
||||
--spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1)
|
||||
(env: LLAMA_ARG_SPEC_TYPE)
|
||||
--spec-default use default speculative decoding
|
||||
```
|
||||
|
||||
### Draft Model Parameters
|
||||
|
||||
```
|
||||
--spec-draft-model, -md, --model-draft FNAME
|
||||
draft model for speculative decoding (default: unused)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_MODEL)
|
||||
--spec-draft-hf, -hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]
|
||||
HuggingFace repository for the draft model
|
||||
--spec-draft-n-max N
|
||||
number of tokens to draft for speculative decoding (default: 16)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_N_MAX)
|
||||
--spec-draft-n-min N
|
||||
minimum number of draft tokens to use for speculative decoding (default: 0)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_N_MIN)
|
||||
--spec-draft-p-split, --draft-p-split P
|
||||
speculative decoding split probability (default: 0.10)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT)
|
||||
--spec-draft-p-min, --draft-p-min P
|
||||
minimum speculative decoding probability (greedy) (default: 0.75)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_P_MIN)
|
||||
--spec-draft-ctx-size, -cd, --ctx-size-draft N
|
||||
size of the prompt context for the draft model (default: 0, 0 = loaded from model)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE)
|
||||
--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N
|
||||
max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)
|
||||
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT)
|
||||
--spec-draft-device, -devd, --device-draft <dev1,dev2,..>
|
||||
comma-separated list of devices to use for offloading the draft model
|
||||
--spec-draft-replace, --spec-replace TARGET DRAFT
|
||||
translate the string in TARGET into DRAFT if the draft model and main model are not compatible
|
||||
```
|
||||
|
||||
### n-gram Mod Parameters
|
||||
|
||||
```
|
||||
--spec-ngram-mod-n-match N
|
||||
ngram-mod lookup length (default: 24)
|
||||
--spec-ngram-mod-n-min N
|
||||
minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48)
|
||||
--spec-ngram-mod-n-max N
|
||||
maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64)
|
||||
```
|
||||
|
||||
### n-gram Simple Parameters
|
||||
|
||||
```
|
||||
--spec-ngram-simple-size-n N
|
||||
ngram size N for ngram-simple speculative decoding, length of lookup n-gram (default: 12)
|
||||
--spec-ngram-simple-size-m N
|
||||
ngram size M for ngram-simple speculative decoding, length of draft m-gram (default: 48)
|
||||
--spec-ngram-simple-min-hits N
|
||||
minimum hits for ngram-simple speculative decoding (default: 1)
|
||||
```
|
||||
|
||||
### n-gram Map Key Parameters
|
||||
|
||||
```
|
||||
--spec-ngram-map-k-size-n N
|
||||
ngram size N for ngram-map-k speculative decoding, length of lookup n-gram (default: 12)
|
||||
--spec-ngram-map-k-size-m N
|
||||
ngram size M for ngram-map-k speculative decoding, length of draft m-gram (default: 48)
|
||||
--spec-ngram-map-k-min-hits N
|
||||
minimum hits for ngram-map-k speculative decoding (default: 1)
|
||||
```
|
||||
|
||||
### n-gram Map Key-4-Values Parameters
|
||||
|
||||
```
|
||||
--spec-ngram-map-k4v-size-n N
|
||||
ngram size N for ngram-map-k4v speculative decoding, length of lookup n-gram (default: 12)
|
||||
--spec-ngram-map-k4v-size-m N
|
||||
ngram size M for ngram-map-k4v speculative decoding, length of draft m-gram (default: 48)
|
||||
--spec-ngram-map-k4v-min-hits N
|
||||
minimum hits for ngram-map-k4v speculative decoding (default: 1)
|
||||
```
|
||||
|
||||
### `--spec-type TYPE`
|
||||
@@ -140,21 +209,40 @@ Specifies a type of speculative decoding without draft model.
|
||||
./llama-server [...] --spec-type ngram-simple
|
||||
```
|
||||
|
||||
### `--spec-ngram-size-n N`
|
||||
### `--spec-ngram-*-size-n N`
|
||||
|
||||
Sets the size N of the lookup n-gram for n-gram map based speculative decoding.
|
||||
The n-gram size N determines how many tokens in a row to look back when searching for matching patterns.
|
||||
|
||||
### `--spec-ngram-size-m M`
|
||||
Each n-gram implementation has its own parameter:
|
||||
|
||||
- `--spec-ngram-simple-size-n` for `ngram-simple`
|
||||
- `--spec-ngram-map-k-size-n` for `ngram-map-k`
|
||||
- `--spec-ngram-map-k4v-size-n` for `ngram-map-k4v`
|
||||
- `--spec-ngram-mod-n-match` for `ngram-mod`
|
||||
|
||||
### `--spec-ngram-*-size-m M`
|
||||
|
||||
Sets the size M of the draft m-gram for n-gram map based speculative decoding.
|
||||
The m-gram size determines how many tokens to draft when a match is found.
|
||||
Larger values can provide more speedup but may reduce acceptance rate.
|
||||
|
||||
### `--spec-ngram-min-hits H`
|
||||
Each n-gram implementation has its own parameter:
|
||||
|
||||
- `--spec-ngram-simple-size-m` for `ngram-simple`
|
||||
- `--spec-ngram-map-k-size-m` for `ngram-map-k`
|
||||
- `--spec-ngram-map-k4v-size-m` for `ngram-map-k4v`
|
||||
|
||||
### `--spec-ngram-*-min-hits H`
|
||||
|
||||
This option defines how often a key has to appear in the token history to be used as a draft (default is 1).
|
||||
|
||||
Each n-gram implementation has its own parameter:
|
||||
|
||||
- `--spec-ngram-simple-min-hits` for `ngram-simple`
|
||||
- `--spec-ngram-map-k-min-hits` for `ngram-map-k`
|
||||
- `--spec-ngram-map-k4v-min-hits` for `ngram-map-k4v`
|
||||
|
||||
## Statistics
|
||||
Each speculative decoding implementation prints statistics.
|
||||
|
||||
@@ -180,4 +268,3 @@ statistics ngram_map_k: #calls(b,g,a) = 6 1690 26, #gen drafts = 26, #acc drafts
|
||||
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
|
||||
- `#acc tokens`: number of tokens accepted by the main model
|
||||
- `dur(b,g,a): durations of begin (new prompt), generation and accumulation (process acceptance).
|
||||
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
set(TARGET llama-diffusion)
|
||||
add_library(${TARGET} STATIC diffusion.cpp diffusion.h)
|
||||
target_link_libraries(${TARGET} PUBLIC llama llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PUBLIC cxx_std_17)
|
||||
|
||||
set(TARGET llama-diffusion-cli)
|
||||
add_executable(${TARGET} diffusion-cli.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE llama llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-diffusion llama llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -12,11 +12,11 @@ The diffusion CLI supports various parameters to control the generation process:
|
||||
### Core Diffusion Parameters
|
||||
- `--diffusion-steps`: Number of diffusion steps (default: 256)
|
||||
- `--diffusion-algorithm`: Algorithm for token selection
|
||||
- `0`: ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006.
|
||||
- `1`: ENTROPY_BASED - Entropy-based selection
|
||||
- `2`: MARGIN_BASED - Margin-based selection
|
||||
- `3`: RANDOM - Random selection
|
||||
- `4`: CONFIDENCE_BASED - Confidence-based selection (default)
|
||||
- `0`: DIFFUSION_ALGORITHM_ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006.
|
||||
- `1`: DIFFUSION_ALGORITHM_ENTROPY_BASED - Entropy-based selection
|
||||
- `2`: DIFFUSION_ALGORITHM_MARGIN_BASED - Margin-based selection
|
||||
- `3`: DIFFUSION_ALGORITHM_RANDOM - Random selection
|
||||
- `4`: DIFFUSION_ALGORITHM_CONFIDENCE_BASED - Confidence-based selection (default)
|
||||
- More documentation here https://github.com/DreamLM/Dream
|
||||
- `--diffusion-visual`: Enable live visualization during generation
|
||||
|
||||
|
||||
@@ -1,127 +1,23 @@
|
||||
#include "arg.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "diffusion.h"
|
||||
#include "llama.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <limits.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <clocale>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 };
|
||||
|
||||
// Unified transfer scheduling methods
|
||||
enum transfer_schedule {
|
||||
TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining
|
||||
BLOCK_BASED = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens
|
||||
};
|
||||
|
||||
typedef bool (*diffusion_step_callback_t)(int32_t step,
|
||||
int32_t total_steps,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
void * user_data);
|
||||
|
||||
struct diffusion_params {
|
||||
int32_t steps = 0;
|
||||
float temperature = 0;
|
||||
llama_token mask_token_id = LLAMA_TOKEN_NULL;
|
||||
diffusion_step_callback_t step_callback = nullptr;
|
||||
void * step_callback_user_data = nullptr;
|
||||
int32_t seed = 0;
|
||||
bool visual_mode = false;
|
||||
bool shift_logits = false; // Shift logits by -1 after decode
|
||||
|
||||
float top_p = 0.;
|
||||
int32_t top_k = 0.;
|
||||
|
||||
diffusion_algorithm algorithm = CONFIDENCE_BASED;
|
||||
transfer_schedule schedule = TIMESTEP_BASED;
|
||||
|
||||
float cfg_scale = 0.; // Config scale for classifier-free guidance
|
||||
float eps = 0.; // Timestep scheduling
|
||||
int32_t block_length = 0; // Block size (for block scheduling)
|
||||
float alg_temp = 0; // algorithm temperature (0.0 = deterministic)
|
||||
bool add_gumbel_noise = false; // Add gumbel noise to the logits if temp > 0.0
|
||||
|
||||
int32_t max_length = 0; // Maximum sequence length
|
||||
};
|
||||
|
||||
struct callback_data {
|
||||
diffusion_params * diff_params;
|
||||
const llama_vocab * vocab;
|
||||
int32_t n_input;
|
||||
};
|
||||
|
||||
static float calculate_confidence(const llama_token_data_array & cur_p,
|
||||
diffusion_algorithm algorithm,
|
||||
std::mt19937 & rng) {
|
||||
switch (algorithm) {
|
||||
case CONFIDENCE_BASED:
|
||||
return cur_p.data[cur_p.selected].p; // Selected token probability
|
||||
|
||||
case ENTROPY_BASED:
|
||||
{
|
||||
float entropy = 0.0f;
|
||||
const float epsilon = 1e-10f;
|
||||
for (size_t i = 0; i < cur_p.size; i++) {
|
||||
float prob = cur_p.data[i].p;
|
||||
entropy += prob * logf(prob + epsilon);
|
||||
}
|
||||
return -entropy; // Higher entropy = lower confidence
|
||||
}
|
||||
|
||||
case MARGIN_BASED:
|
||||
return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;
|
||||
|
||||
case RANDOM:
|
||||
{
|
||||
std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
|
||||
return uniform(rng); // Random confidence
|
||||
}
|
||||
|
||||
case ORIGIN:
|
||||
return cur_p.data[cur_p.selected].p;
|
||||
|
||||
default:
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// Unified transfer count calculation function
|
||||
static int32_t calculate_transfer_count(int32_t step,
|
||||
int32_t total_steps,
|
||||
int32_t remaining_masked,
|
||||
transfer_schedule schedule,
|
||||
float eps,
|
||||
const std::vector<int32_t> & num_transfer_tokens = {}) {
|
||||
switch (schedule) {
|
||||
case TIMESTEP_BASED:
|
||||
{
|
||||
float t = 1.0f - (float) step / total_steps * (1.0f - eps);
|
||||
float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
|
||||
float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
|
||||
return (int32_t) (remaining_masked * p_transfer);
|
||||
}
|
||||
|
||||
case BLOCK_BASED:
|
||||
if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
|
||||
return num_transfer_tokens[step];
|
||||
}
|
||||
return remaining_masked / (total_steps - step); // Fallback
|
||||
|
||||
default:
|
||||
return remaining_masked / (total_steps - step);
|
||||
}
|
||||
}
|
||||
|
||||
static bool diffusion_step_callback(int32_t step,
|
||||
int32_t total_steps,
|
||||
const llama_token * tokens,
|
||||
@@ -176,341 +72,6 @@ static bool diffusion_step_callback(int32_t step,
|
||||
return true;
|
||||
}
|
||||
|
||||
static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) {
|
||||
if (temperature == 0.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::uniform_real_distribution<double> uniform(0.0, 1.0);
|
||||
for (int32_t i = 0; i < n_vocab; i++) {
|
||||
double noise = uniform(rng);
|
||||
// Prevent log(0)
|
||||
noise = std::max(noise, 1e-20);
|
||||
double gumbel_noise = std::pow(-std::log(noise), temperature);
|
||||
logits[i] = std::exp(logits[i]) / gumbel_noise;
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
|
||||
std::vector<int32_t> num_transfer_tokens(steps);
|
||||
|
||||
int32_t base = mask_count / steps;
|
||||
int32_t remainder = mask_count % steps;
|
||||
|
||||
for (int32_t i = 0; i < steps; i++) {
|
||||
num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
|
||||
}
|
||||
|
||||
return num_transfer_tokens;
|
||||
}
|
||||
|
||||
static void diffusion_generate(llama_context * ctx,
|
||||
const llama_token * input_tokens,
|
||||
llama_token * output_tokens,
|
||||
int32_t n_input,
|
||||
const diffusion_params & params,
|
||||
int32_t & n_generated) {
|
||||
n_generated = 0;
|
||||
if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
|
||||
// Initialize with input and pad with mask tokens
|
||||
std::copy(input_tokens, input_tokens + n_input, output_tokens);
|
||||
std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
|
||||
llama_set_causal_attn(ctx, false);
|
||||
|
||||
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
||||
|
||||
std::vector<llama_token_data> candidates(n_vocab);
|
||||
std::vector<llama_token_data> conf_candidates;
|
||||
conf_candidates.reserve(params.max_length);
|
||||
std::vector<int32_t> mask_positions;
|
||||
mask_positions.reserve(params.max_length);
|
||||
|
||||
// Setup sampler chain
|
||||
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
|
||||
if (params.top_k > 0) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
|
||||
}
|
||||
if (params.top_p < 1.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
|
||||
}
|
||||
if (params.temperature > 0.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
|
||||
}
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));
|
||||
|
||||
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
|
||||
|
||||
llama_batch batch = llama_batch_init(params.max_length, 0, 1);
|
||||
batch.n_tokens = params.max_length;
|
||||
|
||||
// Pre-allocate buffers for CFG if needed
|
||||
int32_t logits_size = n_vocab * params.max_length;
|
||||
std::vector<float> cond_logits_buffer;
|
||||
std::vector<llama_token> un_x_buffer;
|
||||
if (params.cfg_scale > 0.0f) {
|
||||
cond_logits_buffer.resize(logits_size);
|
||||
un_x_buffer.resize(params.max_length);
|
||||
}
|
||||
|
||||
// For block-based processing
|
||||
std::vector<int32_t> num_transfer_tokens;
|
||||
int32_t num_blocks = 1;
|
||||
int32_t steps_per_block = params.steps;
|
||||
|
||||
if (params.schedule == BLOCK_BASED) {
|
||||
GGML_ASSERT(params.max_length % params.block_length == 0);
|
||||
num_blocks = params.max_length / params.block_length;
|
||||
GGML_ASSERT(params.steps % num_blocks == 0);
|
||||
steps_per_block = params.steps / num_blocks;
|
||||
}
|
||||
|
||||
std::vector<float> confidence(params.max_length);
|
||||
|
||||
int64_t total_sampling_time = 0;
|
||||
int64_t total_time = 0;
|
||||
int64_t time_start = ggml_time_us();
|
||||
|
||||
for (int block_num = 0; block_num < num_blocks; block_num++) {
|
||||
int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
|
||||
int32_t block_end = (params.schedule == BLOCK_BASED) ?
|
||||
std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
|
||||
params.max_length;
|
||||
|
||||
// Count masked tokens in current block for block-based processing
|
||||
if (params.schedule == BLOCK_BASED) {
|
||||
int32_t block_mask_count = 0;
|
||||
for (int i = block_start; i < block_end; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
block_mask_count++;
|
||||
}
|
||||
}
|
||||
num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
|
||||
}
|
||||
|
||||
for (int32_t step = 0; step < steps_per_block; step++) {
|
||||
int32_t global_step = block_num * steps_per_block + step;
|
||||
|
||||
if (params.step_callback) {
|
||||
if (!params.step_callback(
|
||||
global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Setup batch
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
batch.token[i] = output_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.logits[i] = 1;
|
||||
}
|
||||
|
||||
float * logits = nullptr;
|
||||
|
||||
if (params.cfg_scale > 0.0f) {
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("Failed to generate conditional");
|
||||
break;
|
||||
}
|
||||
float * cond_logits_ptr = llama_get_logits(ctx);
|
||||
std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));
|
||||
|
||||
// Unconditional generation (mask input)
|
||||
std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
|
||||
for (int32_t i = 0; i < n_input; i++) {
|
||||
un_x_buffer[i] = params.mask_token_id;
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
batch.token[i] = un_x_buffer[i];
|
||||
}
|
||||
ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("Failed to generate unconditional");
|
||||
break;
|
||||
}
|
||||
float * uncond_logits = llama_get_logits(ctx);
|
||||
|
||||
// Apply CFG
|
||||
for (int32_t i = 0; i < logits_size; i++) {
|
||||
cond_logits_buffer[i] =
|
||||
uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
|
||||
}
|
||||
logits = cond_logits_buffer.data();
|
||||
} else {
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
|
||||
break;
|
||||
}
|
||||
logits = llama_get_logits(ctx);
|
||||
}
|
||||
|
||||
if (!logits) {
|
||||
LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
|
||||
break;
|
||||
}
|
||||
|
||||
auto get_logits_for_pos = [&](int32_t pos) -> const float * {
|
||||
if (params.shift_logits) {
|
||||
return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
|
||||
}
|
||||
return logits + (pos) *n_vocab;
|
||||
};
|
||||
|
||||
int64_t time_start_sampling = ggml_time_us();
|
||||
|
||||
mask_positions.clear();
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
// For block-based, only consider current block
|
||||
if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) {
|
||||
mask_positions.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mask_positions.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (params.add_gumbel_noise && params.temperature > 0.0f) {
|
||||
add_gumbel_noise(logits, n_vocab, params.temperature, rng);
|
||||
}
|
||||
|
||||
if (params.algorithm == ORIGIN) {
|
||||
int32_t transfer_count = calculate_transfer_count(
|
||||
step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
|
||||
float p_transfer = (float) transfer_count / mask_positions.size();
|
||||
|
||||
for (int32_t pos : mask_positions) {
|
||||
if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].id = token_id;
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
candidates.data(),
|
||||
(size_t) n_vocab,
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
output_tokens[pos] = cur_p.data[cur_p.selected].id;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::vector<std::pair<float, int32_t>> confidences;
|
||||
std::vector<llama_token> sampled_tokens(mask_positions.size());
|
||||
|
||||
for (size_t i = 0; i < mask_positions.size(); i++) {
|
||||
int32_t pos = mask_positions[i];
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
candidates[token_id].id = token_id;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
candidates.data(),
|
||||
candidates.size(),
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
llama_token sampled_token = cur_p.data[cur_p.selected].id;
|
||||
|
||||
float conf = calculate_confidence(cur_p, params.algorithm, rng);
|
||||
|
||||
sampled_tokens[i] = sampled_token;
|
||||
confidences.emplace_back(conf, i);
|
||||
}
|
||||
|
||||
int32_t transfer_count = calculate_transfer_count(
|
||||
step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
|
||||
|
||||
if (transfer_count > 0) {
|
||||
if (params.alg_temp == 0.0f) {
|
||||
std::partial_sort(confidences.begin(),
|
||||
confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
|
||||
confidences.end(),
|
||||
[](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
|
||||
if (a.first != b.first) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
return a.second < b.second;
|
||||
});
|
||||
|
||||
for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
|
||||
int32_t mask_idx = confidences[i].second;
|
||||
int32_t pos = mask_positions[mask_idx];
|
||||
output_tokens[pos] = sampled_tokens[mask_idx];
|
||||
}
|
||||
} else {
|
||||
conf_candidates.clear();
|
||||
for (size_t i = 0; i < confidences.size(); i++) {
|
||||
float conf_logit = confidences[i].first / params.alg_temp;
|
||||
conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array conf_array = {
|
||||
conf_candidates.data(),
|
||||
conf_candidates.size(),
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
|
||||
llama_sampler_apply(dist_sampler, &conf_array);
|
||||
int32_t selected_idx = conf_array.selected;
|
||||
int32_t mask_idx = selected_idx;
|
||||
int32_t pos = mask_positions[mask_idx];
|
||||
output_tokens[pos] = sampled_tokens[mask_idx];
|
||||
|
||||
conf_candidates[selected_idx].p = 0.0f;
|
||||
conf_array.selected = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int64_t time_end_sampling = ggml_time_us();
|
||||
total_sampling_time += time_end_sampling - time_start_sampling;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t time_end = ggml_time_us();
|
||||
total_time += time_end - time_start;
|
||||
|
||||
LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
|
||||
total_time / 1000.0,
|
||||
total_time / 1000.0 / params.steps,
|
||||
total_sampling_time / 1000.0 / params.steps);
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_sampler_free(sampler);
|
||||
llama_sampler_free(dist_sampler);
|
||||
|
||||
n_generated = params.max_length;
|
||||
}
|
||||
|
||||
static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) {
|
||||
if (!use_chat_template) {
|
||||
return prompt;
|
||||
@@ -631,10 +192,10 @@ int main(int argc, char ** argv) {
|
||||
GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0));
|
||||
|
||||
if (params.diffusion.eps) {
|
||||
diff_params.schedule = TIMESTEP_BASED;
|
||||
diff_params.schedule = DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED;
|
||||
diff_params.eps = params.diffusion.eps;
|
||||
} else if (params.diffusion.block_length) {
|
||||
diff_params.schedule = BLOCK_BASED;
|
||||
diff_params.schedule = DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED;
|
||||
diff_params.block_length = params.diffusion.block_length;
|
||||
}
|
||||
|
||||
@@ -653,8 +214,17 @@ int main(int argc, char ** argv) {
|
||||
callback_data cb_data = { &diff_params, vocab, n_input };
|
||||
diff_params.step_callback_user_data = &cb_data;
|
||||
|
||||
const char * alg_names[] = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" };
|
||||
const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" };
|
||||
const char * alg_names[] = {
|
||||
"DIFFUSION_ALGORITHM_ORIGIN",
|
||||
"DIFFUSION_ALGORITHM_ENTROPY_BASED",
|
||||
"DIFFUSION_ALGORITHM_MARGIN_BASED",
|
||||
"DIFFUSION_ALGORITHM_RANDOM",
|
||||
"DIFFUSION_ALGORITHM_CONFIDENCE_BASED",
|
||||
};
|
||||
const char * sched_names[] = {
|
||||
"DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED",
|
||||
"DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED",
|
||||
};
|
||||
const char * alg_name =
|
||||
(diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN";
|
||||
const char * sched_name =
|
||||
@@ -666,11 +236,11 @@ int main(int argc, char ** argv) {
|
||||
LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name);
|
||||
LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature);
|
||||
if (diff_params.schedule == TIMESTEP_BASED) {
|
||||
if (diff_params.schedule == DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED) {
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp);
|
||||
}
|
||||
if (diff_params.schedule == BLOCK_BASED) {
|
||||
if (diff_params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) {
|
||||
LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale);
|
||||
}
|
||||
|
||||
408
examples/diffusion/diffusion.cpp
Normal file
408
examples/diffusion/diffusion.cpp
Normal file
@@ -0,0 +1,408 @@
|
||||
#include "diffusion.h"
|
||||
|
||||
#include "log.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
static float calculate_confidence(const llama_token_data_array & cur_p,
|
||||
diffusion_algorithm algorithm,
|
||||
std::mt19937 & rng) {
|
||||
switch (algorithm) {
|
||||
case DIFFUSION_ALGORITHM_CONFIDENCE_BASED:
|
||||
return cur_p.data[cur_p.selected].p; // Selected token probability
|
||||
|
||||
case DIFFUSION_ALGORITHM_ENTROPY_BASED:
|
||||
{
|
||||
float entropy = 0.0f;
|
||||
const float epsilon = 1e-10f;
|
||||
for (size_t i = 0; i < cur_p.size; i++) {
|
||||
float prob = cur_p.data[i].p;
|
||||
entropy += prob * logf(prob + epsilon);
|
||||
}
|
||||
return -entropy; // Higher entropy = lower confidence
|
||||
}
|
||||
|
||||
case DIFFUSION_ALGORITHM_MARGIN_BASED:
|
||||
return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;
|
||||
|
||||
case DIFFUSION_ALGORITHM_RANDOM:
|
||||
{
|
||||
std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
|
||||
return uniform(rng); // Random confidence
|
||||
}
|
||||
|
||||
case DIFFUSION_ALGORITHM_ORIGIN:
|
||||
return cur_p.data[cur_p.selected].p;
|
||||
|
||||
default:
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// Unified transfer count calculation function
|
||||
static int32_t calculate_transfer_count(int32_t step,
|
||||
int32_t total_steps,
|
||||
int32_t remaining_masked,
|
||||
diffusion_transfer_schedule schedule,
|
||||
float eps,
|
||||
const std::vector<int32_t> & num_transfer_tokens = {}) {
|
||||
switch (schedule) {
|
||||
case DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED:
|
||||
{
|
||||
float t = 1.0f - (float) step / total_steps * (1.0f - eps);
|
||||
float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
|
||||
float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
|
||||
return (int32_t) (remaining_masked * p_transfer);
|
||||
}
|
||||
|
||||
case DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED:
|
||||
if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
|
||||
return num_transfer_tokens[step];
|
||||
}
|
||||
return remaining_masked / (total_steps - step); // Fallback
|
||||
|
||||
default:
|
||||
return remaining_masked / (total_steps - step);
|
||||
}
|
||||
}
|
||||
|
||||
static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) {
|
||||
if (temperature == 0.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::uniform_real_distribution<double> uniform(0.0, 1.0);
|
||||
for (int32_t i = 0; i < n_vocab; i++) {
|
||||
double noise = uniform(rng);
|
||||
// Prevent log(0)
|
||||
noise = std::max(noise, 1e-20);
|
||||
double gumbel_noise = std::pow(-std::log(noise), temperature);
|
||||
logits[i] = std::exp(logits[i]) / gumbel_noise;
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
|
||||
std::vector<int32_t> num_transfer_tokens(steps);
|
||||
|
||||
int32_t base = mask_count / steps;
|
||||
int32_t remainder = mask_count % steps;
|
||||
|
||||
for (int32_t i = 0; i < steps; i++) {
|
||||
num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
|
||||
}
|
||||
|
||||
return num_transfer_tokens;
|
||||
}
|
||||
|
||||
void diffusion_generate(llama_context * ctx,
|
||||
const llama_token * input_tokens,
|
||||
llama_token * output_tokens,
|
||||
int32_t n_input,
|
||||
const diffusion_params & params,
|
||||
int32_t & n_generated) {
|
||||
n_generated = 0;
|
||||
if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
|
||||
// Initialize with input and pad with mask tokens
|
||||
std::copy(input_tokens, input_tokens + n_input, output_tokens);
|
||||
std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
|
||||
llama_set_causal_attn(ctx, false);
|
||||
|
||||
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
||||
|
||||
std::vector<llama_token_data> candidates(n_vocab);
|
||||
std::vector<llama_token_data> conf_candidates;
|
||||
conf_candidates.reserve(params.max_length);
|
||||
std::vector<int32_t> mask_positions;
|
||||
mask_positions.reserve(params.max_length);
|
||||
|
||||
// Setup sampler chain
|
||||
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
|
||||
if (params.top_k > 0) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
|
||||
}
|
||||
if (params.top_p < 1.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
|
||||
}
|
||||
if (params.temperature > 0.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
|
||||
}
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));
|
||||
|
||||
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
|
||||
|
||||
llama_batch batch = llama_batch_init(params.max_length, 0, 1);
|
||||
batch.n_tokens = params.max_length;
|
||||
|
||||
// Pre-allocate buffers for CFG if needed
|
||||
int32_t logits_size = n_vocab * params.max_length;
|
||||
std::vector<float> cond_logits_buffer;
|
||||
std::vector<llama_token> un_x_buffer;
|
||||
if (params.cfg_scale > 0.0f) {
|
||||
cond_logits_buffer.resize(logits_size);
|
||||
un_x_buffer.resize(params.max_length);
|
||||
}
|
||||
|
||||
// For block-based processing
|
||||
std::vector<int32_t> num_transfer_tokens;
|
||||
int32_t num_blocks = 1;
|
||||
int32_t steps_per_block = params.steps;
|
||||
|
||||
if (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) {
|
||||
GGML_ASSERT(params.max_length % params.block_length == 0);
|
||||
num_blocks = params.max_length / params.block_length;
|
||||
GGML_ASSERT(params.steps % num_blocks == 0);
|
||||
steps_per_block = params.steps / num_blocks;
|
||||
}
|
||||
|
||||
std::vector<float> confidence(params.max_length);
|
||||
|
||||
int64_t total_sampling_time = 0;
|
||||
int64_t total_time = 0;
|
||||
int64_t time_start = ggml_time_us();
|
||||
|
||||
for (int block_num = 0; block_num < num_blocks; block_num++) {
|
||||
int32_t block_start = (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
|
||||
int32_t block_end = (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) ?
|
||||
std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
|
||||
params.max_length;
|
||||
|
||||
// Count masked tokens in current block for block-based processing
|
||||
if (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) {
|
||||
int32_t block_mask_count = 0;
|
||||
for (int i = block_start; i < block_end; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
block_mask_count++;
|
||||
}
|
||||
}
|
||||
num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
|
||||
}
|
||||
|
||||
for (int32_t step = 0; step < steps_per_block; step++) {
|
||||
int32_t global_step = block_num * steps_per_block + step;
|
||||
|
||||
if (params.step_callback) {
|
||||
if (!params.step_callback(
|
||||
global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Setup batch
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
batch.token[i] = output_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.logits[i] = 1;
|
||||
}
|
||||
|
||||
float * logits = nullptr;
|
||||
|
||||
if (params.cfg_scale > 0.0f) {
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("Failed to generate conditional");
|
||||
break;
|
||||
}
|
||||
float * cond_logits_ptr = llama_get_logits(ctx);
|
||||
std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));
|
||||
|
||||
// Unconditional generation (mask input)
|
||||
std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
|
||||
for (int32_t i = 0; i < n_input; i++) {
|
||||
un_x_buffer[i] = params.mask_token_id;
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
batch.token[i] = un_x_buffer[i];
|
||||
}
|
||||
ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("Failed to generate unconditional");
|
||||
break;
|
||||
}
|
||||
float * uncond_logits = llama_get_logits(ctx);
|
||||
|
||||
// Apply CFG
|
||||
for (int32_t i = 0; i < logits_size; i++) {
|
||||
cond_logits_buffer[i] =
|
||||
uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
|
||||
}
|
||||
logits = cond_logits_buffer.data();
|
||||
} else {
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
|
||||
break;
|
||||
}
|
||||
logits = llama_get_logits(ctx);
|
||||
}
|
||||
|
||||
if (!logits) {
|
||||
LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
|
||||
break;
|
||||
}
|
||||
|
||||
auto get_logits_for_pos = [&](int32_t pos) -> const float * {
|
||||
if (params.shift_logits) {
|
||||
return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
|
||||
}
|
||||
return logits + pos * n_vocab;
|
||||
};
|
||||
|
||||
int64_t time_start_sampling = ggml_time_us();
|
||||
|
||||
mask_positions.clear();
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
// For block-based, only consider current block
|
||||
if (params.schedule != DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED || (i >= block_start && i < block_end)) {
|
||||
mask_positions.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mask_positions.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (params.add_gumbel_noise && params.temperature > 0.0f) {
|
||||
add_gumbel_noise(logits, n_vocab, params.temperature, rng);
|
||||
}
|
||||
|
||||
if (params.algorithm == DIFFUSION_ALGORITHM_ORIGIN) {
|
||||
int32_t transfer_count = calculate_transfer_count(
|
||||
step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
|
||||
float p_transfer = (float) transfer_count / mask_positions.size();
|
||||
|
||||
for (int32_t pos : mask_positions) {
|
||||
if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].id = token_id;
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
candidates.data(),
|
||||
(size_t) n_vocab,
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
output_tokens[pos] = cur_p.data[cur_p.selected].id;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::vector<std::pair<float, int32_t>> confidences;
|
||||
std::vector<llama_token> sampled_tokens(mask_positions.size());
|
||||
|
||||
for (size_t i = 0; i < mask_positions.size(); i++) {
|
||||
int32_t pos = mask_positions[i];
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
candidates[token_id].id = token_id;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
candidates.data(),
|
||||
candidates.size(),
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
llama_token sampled_token = cur_p.data[cur_p.selected].id;
|
||||
|
||||
float conf = calculate_confidence(cur_p, params.algorithm, rng);
|
||||
|
||||
sampled_tokens[i] = sampled_token;
|
||||
confidences.emplace_back(conf, i);
|
||||
}
|
||||
|
||||
int32_t transfer_count = calculate_transfer_count(
|
||||
step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
|
||||
|
||||
if (transfer_count > 0) {
|
||||
if (params.alg_temp == 0.0f) {
|
||||
std::partial_sort(confidences.begin(),
|
||||
confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
|
||||
confidences.end(),
|
||||
[](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
|
||||
if (a.first != b.first) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
return a.second < b.second;
|
||||
});
|
||||
|
||||
for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
|
||||
int32_t mask_idx = confidences[i].second;
|
||||
int32_t pos = mask_positions[mask_idx];
|
||||
output_tokens[pos] = sampled_tokens[mask_idx];
|
||||
}
|
||||
} else {
|
||||
conf_candidates.clear();
|
||||
for (size_t i = 0; i < confidences.size(); i++) {
|
||||
float conf_logit = confidences[i].first / params.alg_temp;
|
||||
conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array conf_array = {
|
||||
conf_candidates.data(),
|
||||
conf_candidates.size(),
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
|
||||
llama_sampler_apply(dist_sampler, &conf_array);
|
||||
int32_t selected_idx = conf_array.selected;
|
||||
int32_t mask_idx = selected_idx;
|
||||
int32_t pos = mask_positions[mask_idx];
|
||||
output_tokens[pos] = sampled_tokens[mask_idx];
|
||||
|
||||
conf_candidates[selected_idx].p = 0.0f;
|
||||
conf_array.selected = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int64_t time_end_sampling = ggml_time_us();
|
||||
total_sampling_time += time_end_sampling - time_start_sampling;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t time_end = ggml_time_us();
|
||||
total_time += time_end - time_start;
|
||||
|
||||
LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
|
||||
total_time / 1000.0,
|
||||
total_time / 1000.0 / params.steps,
|
||||
total_sampling_time / 1000.0 / params.steps);
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_sampler_free(sampler);
|
||||
llama_sampler_free(dist_sampler);
|
||||
|
||||
n_generated = params.max_length;
|
||||
}
|
||||
57
examples/diffusion/diffusion.h
Normal file
57
examples/diffusion/diffusion.h
Normal file
@@ -0,0 +1,57 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
enum diffusion_algorithm {
|
||||
DIFFUSION_ALGORITHM_ORIGIN = 0,
|
||||
DIFFUSION_ALGORITHM_ENTROPY_BASED = 1,
|
||||
DIFFUSION_ALGORITHM_MARGIN_BASED = 2,
|
||||
DIFFUSION_ALGORITHM_RANDOM = 3,
|
||||
DIFFUSION_ALGORITHM_CONFIDENCE_BASED = 4,
|
||||
};
|
||||
|
||||
// Unified transfer scheduling methods
|
||||
enum diffusion_transfer_schedule {
|
||||
DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining
|
||||
DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens
|
||||
};
|
||||
|
||||
typedef bool (*diffusion_step_callback_t)(int32_t step,
|
||||
int32_t total_steps,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
void * user_data);
|
||||
|
||||
struct diffusion_params {
|
||||
int32_t steps = 0;
|
||||
float temperature = 0;
|
||||
llama_token mask_token_id = LLAMA_TOKEN_NULL;
|
||||
diffusion_step_callback_t step_callback = nullptr;
|
||||
void * step_callback_user_data = nullptr;
|
||||
int32_t seed = 0;
|
||||
bool visual_mode = false;
|
||||
bool shift_logits = false; // Shift logits by -1 after decode
|
||||
|
||||
float top_p = 0.;
|
||||
int32_t top_k = 0.;
|
||||
|
||||
diffusion_algorithm algorithm = DIFFUSION_ALGORITHM_CONFIDENCE_BASED;
|
||||
diffusion_transfer_schedule schedule = DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED;
|
||||
|
||||
float cfg_scale = 0.; // Config scale for classifier-free guidance
|
||||
float eps = 0.; // Timestep scheduling
|
||||
int32_t block_length = 0; // Block size (for block scheduling)
|
||||
float alg_temp = 0; // algorithm temperature (0.0 = deterministic)
|
||||
bool add_gumbel_noise = false; // Add gumbel noise to the logits if temp > 0.0
|
||||
|
||||
int32_t max_length = 0; // Maximum sequence length
|
||||
};
|
||||
|
||||
void diffusion_generate(llama_context * ctx,
|
||||
const llama_token * input_tokens,
|
||||
llama_token * output_tokens,
|
||||
int32_t n_input,
|
||||
const diffusion_params & params,
|
||||
int32_t & n_generated);
|
||||
@@ -38,8 +38,12 @@ int main(int argc, char ** argv) {
|
||||
std::string result0;
|
||||
std::string result1;
|
||||
std::string result2;
|
||||
std::string result3;
|
||||
|
||||
// init
|
||||
|
||||
ggml_backend_load_all();
|
||||
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
auto * model = llama_init->model();
|
||||
@@ -213,11 +217,83 @@ int main(int argc, char ** argv) {
|
||||
n_past += 1;
|
||||
}
|
||||
|
||||
// test on-device state save/load
|
||||
auto params_ctx4 = common_context_params_to_llama(params);
|
||||
params_ctx4.n_seq_max = 2;
|
||||
llama_context * ctx4 = llama_init_from_model(model, params_ctx4);
|
||||
|
||||
llama_sampler * smpl4 = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl4, llama_sampler_init_dist(params.sampling.seed));
|
||||
|
||||
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||
|
||||
// load state (rng, logits, embedding and kv_cache) from file
|
||||
n_token_count_out = 0;
|
||||
|
||||
if (!llama_state_load_file(ctx4, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
|
||||
fprintf(stderr, "\n%s : failed to load state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
|
||||
|
||||
// restore state (last tokens)
|
||||
n_past = n_token_count_out;
|
||||
if (!common_replay_last_token(ctx4, tokens.back(), n_past)) {
|
||||
return 1;
|
||||
}
|
||||
++n_past;
|
||||
|
||||
// save seq 0 and load into seq 1
|
||||
{
|
||||
// save kv of seq 0
|
||||
std::vector<uint8_t> seq_store(llama_state_seq_get_size_ext(ctx4, 0, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE));
|
||||
const size_t ncopy = llama_state_seq_get_data_ext(ctx4, seq_store.data(), seq_store.size(), 0, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (ncopy != seq_store.size()) {
|
||||
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
||||
return 1;
|
||||
}
|
||||
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
|
||||
|
||||
// erase whole kv
|
||||
llama_memory_clear(llama_get_memory(ctx4), true);
|
||||
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
||||
|
||||
// restore kv into seq 0
|
||||
const size_t nset = llama_state_seq_set_data_ext(ctx4, seq_store.data(), seq_store.size(), 1, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (nset != seq_store.size()) {
|
||||
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
||||
return 1;
|
||||
}
|
||||
fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
|
||||
}
|
||||
|
||||
// forth run
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto next_token = llama_sampler_sample(smpl4, ctx4, -1);
|
||||
auto next_token_str = common_token_to_piece(ctx4, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result3 += next_token_str;
|
||||
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, next_token, n_past, {1}, true);
|
||||
|
||||
if (llama_decode(ctx4, batch)) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
return 1;
|
||||
}
|
||||
n_past += 1;
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
|
||||
llama_sampler_free(smpl);
|
||||
llama_sampler_free(smpl2);
|
||||
llama_sampler_free(smpl3);
|
||||
llama_sampler_free(smpl4);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
@@ -226,12 +302,18 @@ int main(int argc, char ** argv) {
|
||||
|
||||
llama_free(ctx2);
|
||||
llama_free(ctx3);
|
||||
llama_free(ctx4);
|
||||
|
||||
if (result0 != result2) {
|
||||
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (result0 != result3) {
|
||||
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n%s : success\n", __func__);
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -110,13 +110,21 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (
|
||||
llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
|
||||
llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
|
||||
llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
|
||||
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
|
||||
) {
|
||||
LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
|
||||
if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
|
||||
(llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) {
|
||||
LOG_ERR("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
__func__,
|
||||
llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft),
|
||||
llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft));
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
|
||||
(llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) {
|
||||
LOG_ERR("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
__func__,
|
||||
llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft),
|
||||
llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft));
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -137,11 +145,12 @@ int main(int argc, char ** argv) {
|
||||
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
||||
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
|
||||
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
||||
|
||||
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||
LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
|
||||
LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
|
||||
common_token_to_piece(ctx_tgt, i).c_str(),
|
||||
common_token_to_piece(ctx_dft, i).c_str());
|
||||
common_token_to_piece(vocab_tgt, i).c_str(),
|
||||
common_token_to_piece(vocab_dft, i).c_str());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,7 +119,7 @@ if [ $GGML_SYCL_DEVICE -ne -1 ]; then
|
||||
echo "Use $GGML_SYCL_DEVICE as main GPU"
|
||||
#use signle GPU only
|
||||
GPUS_SETTING="-mg $GGML_SYCL_DEVICE -sm ${SPLIT_MODE}"
|
||||
export ONEAPI_DEVICE_SELECTOR="level_zero:${$GGML_SYCL_DEVICE}"
|
||||
export ONEAPI_DEVICE_SELECTOR="level_zero:${GGML_SYCL_DEVICE}"
|
||||
echo "ONEAPI_DEVICE_SELECTOR=${ONEAPI_DEVICE_SELECTOR}"
|
||||
else
|
||||
echo "Use all Intel GPUs, including iGPU & dGPU"
|
||||
|
||||
@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
|
||||
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 10)
|
||||
set(GGML_VERSION_MINOR 11)
|
||||
set(GGML_VERSION_PATCH 0)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
|
||||
@@ -438,6 +438,12 @@ extern "C" {
|
||||
GGML_PREC_F32 = 10,
|
||||
};
|
||||
|
||||
// op hint
|
||||
enum ggml_op_hint {
|
||||
GGML_HINT_NONE = 0,
|
||||
GGML_HINT_SRC0_IS_HADAMARD = 1,
|
||||
};
|
||||
|
||||
// model file types
|
||||
enum ggml_ftype {
|
||||
GGML_FTYPE_UNKNOWN = -1,
|
||||
@@ -1419,6 +1425,11 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_prec prec);
|
||||
|
||||
// change the hint of a matrix multiplication
|
||||
GGML_API void ggml_mul_mat_set_hint(
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_op_hint hint);
|
||||
|
||||
// indirect matrix multiplication
|
||||
GGML_API struct ggml_tensor * ggml_mul_mat_id(
|
||||
struct ggml_context * ctx,
|
||||
|
||||
@@ -1826,7 +1826,24 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
|
||||
continue;
|
||||
}
|
||||
|
||||
i = get_i_delayed(i);
|
||||
const int i_delayed = get_i_delayed(i);
|
||||
|
||||
// If we can delay the AllReduce we need to consider the interaction with zero-sized tensor slices.
|
||||
// A backend with such a slice would normally have valid data after participating in the AllReduce with a node that has
|
||||
// its compute flag disabled and thus gets its data zeroed out.
|
||||
// If the AllReduce is delayed then the nodes until that point also need to have their compute flag disabled.
|
||||
if (i_delayed > i) {
|
||||
for (size_t j = 0; j < n_backends; j++) {
|
||||
auto & bcj = backend_ctx->backend_configs[j];
|
||||
if ((bcj.nodes[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
for (int ii = i + 1; ii <= i_delayed; ii++) {
|
||||
bcj.nodes[ii]->flags &= ~GGML_TENSOR_FLAG_COMPUTE;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
i = i_delayed;
|
||||
|
||||
for (size_t j = 0; j < n_backends; j++) {
|
||||
auto & bcj = backend_ctx->backend_configs[j];
|
||||
@@ -2083,8 +2100,8 @@ static const ggml_backend_i ggml_backend_meta_i = {
|
||||
/* .free = */ ggml_backend_meta_free,
|
||||
/* .set_tensor_async = */ ggml_backend_meta_set_tensor_async,
|
||||
/* .get_tensor_async = */ ggml_backend_meta_get_tensor_async,
|
||||
/* .get_tensor_2d_async = */ nullptr,
|
||||
/* .set_tensor_2d_async = */ nullptr,
|
||||
/* .get_tensor_2d_async = */ nullptr,
|
||||
/* .cpy_tensor_async = */ nullptr,
|
||||
/* .synchronize = */ ggml_backend_meta_synchronize,
|
||||
/* .graph_plan_create = */ nullptr,
|
||||
|
||||
@@ -262,9 +262,9 @@ static struct ggml_backend_i blas_backend_i = {
|
||||
/* .get_name = */ ggml_backend_blas_get_name,
|
||||
/* .free = */ ggml_backend_blas_free,
|
||||
/* .set_tensor_async = */ NULL,
|
||||
/* .get_tensor_2d_async = */ NULL,
|
||||
/* .set_tensor_2d_async = */ NULL,
|
||||
/* .get_tensor_async = */ NULL,
|
||||
/* .set_tensor_2d_async = */ NULL,
|
||||
/* .get_tensor_2d_async = */ NULL,
|
||||
/* .cpy_tensor_async = */ NULL,
|
||||
/* .synchronize = */ NULL,
|
||||
/* .graph_plan_create = */ NULL,
|
||||
|
||||
@@ -2746,8 +2746,8 @@ static const ggml_backend_i ggml_backend_cann_interface = {
|
||||
/* .free = */ ggml_backend_cann_free,
|
||||
/* .set_tensor_async = */ ggml_backend_cann_set_tensor_async,
|
||||
/* .get_tensor_async = */ ggml_backend_cann_get_tensor_async,
|
||||
/* .get_tensor_2d_async = */ NULL,
|
||||
/* .set_tensor_2d_async = */ NULL,
|
||||
/* .get_tensor_2d_async = */ NULL,
|
||||
/* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async,
|
||||
/* .synchronize = */ ggml_backend_cann_synchronize,
|
||||
/* .graph_plan_create = */ NULL,
|
||||
|
||||
@@ -485,6 +485,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
if (GGML_RV_ZIHINTPAUSE)
|
||||
string(APPEND MARCH_STR "_zihintpause")
|
||||
endif()
|
||||
if (GGML_CPU_RISCV64_SPACEMIT)
|
||||
# `xsmtvdotii' is only required for GCC >= 15.
|
||||
if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 15)
|
||||
string(APPEND MARCH_STR "_xsmtvdotii")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
|
||||
else()
|
||||
@@ -571,13 +578,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
|
||||
# Fetch KleidiAI sources:
|
||||
include(FetchContent)
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.22.0")
|
||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1")
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.24.0")
|
||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/releases/download/${KLEIDIAI_COMMIT_TAG}/kleidiai-${KLEIDIAI_COMMIT_TAG}-src.tar.gz")
|
||||
set(KLEIDIAI_RELEASE_ARCHIVE_MD5 "2f02ebe29573d45813e671eb304f2a00")
|
||||
|
||||
set(KLEIDIAI_FETCH_ARGS
|
||||
URL ${KLEIDIAI_DOWNLOAD_URL}
|
||||
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}
|
||||
URL_HASH MD5=${KLEIDIAI_RELEASE_ARCHIVE_MD5}
|
||||
)
|
||||
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
|
||||
list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW)
|
||||
|
||||
@@ -5023,6 +5023,71 @@ void ggml_gemm_q8_0_4x8_q8_0(int n,
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (svcntb() * 8 == 256) {
|
||||
const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
|
||||
|
||||
static const uint32_t idx_arr[8] = {0, 1, 4, 5, 2, 3, 6, 7};
|
||||
svuint32_t idx = svld1(svptrue_b32(), idx_arr);
|
||||
static const uint32_t idx_arr1[8] = {0, 1, 2, 3, 1, 2, 3, 0};
|
||||
svuint32_t idx_sc1 = svld1(svptrue_b32(), idx_arr1);
|
||||
static const uint32_t idx_arr2[8] = {0, 1, 2, 3, 0, 1, 2, 3};
|
||||
svuint32_t idx_sc2 = svld1(svptrue_b32(), idx_arr2);
|
||||
|
||||
for (int y = 0; y < nr; y += 4) {
|
||||
const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
|
||||
|
||||
for (int x = 0; x < nc; x += ncols_interleaved) {
|
||||
const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
|
||||
const block_q8_0x4 * a_ptr = a_ptr_base;
|
||||
|
||||
svfloat32_t acc_f32_01 = svdup_f32(0);
|
||||
svfloat32_t acc_f32_23 = svdup_f32(0);
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
|
||||
svint32_t acc_01 = svdup_s32(0);
|
||||
svint32_t acc_23 = svdup_s32(0);
|
||||
|
||||
// Process 4 chunks of 8 positions each
|
||||
for (int chunk = 0; chunk < 4; chunk++) {
|
||||
svint8_t s_a01 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32);
|
||||
svint8_t s_a23 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32 + 16);
|
||||
svint8_t s_b0123 = svld1_s8(svptrue_b8(), b_ptr->qs + chunk * 32);
|
||||
|
||||
acc_01 = svmmla_s32(acc_01, s_a01, s_b0123);
|
||||
acc_23 = svmmla_s32(acc_23, s_a23, s_b0123);
|
||||
}
|
||||
|
||||
// Reorder outputs from 2×2 tiles to row-major
|
||||
// acc[01] = [r0c0, r0c1, r1c0, r1c1, r0c2, r0c3, r1c2, r1c3]
|
||||
// acc[23] = [r2c0, r2c1, r3c0, r3c1, r2c2, r2c3, r3c2, r3c3]
|
||||
|
||||
svint32_t row01 = svtbl_s32(acc_01, idx);
|
||||
svint32_t row23 = svtbl_s32(acc_23, idx);
|
||||
|
||||
svfloat16_t temp1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) a_ptr->d);
|
||||
svfloat16_t temp2 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) b_ptr->d);
|
||||
svfloat32_t sv_a_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp1, temp1)), idx_sc1);
|
||||
svfloat32_t sv_b_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp2, temp2)), idx_sc2);
|
||||
|
||||
acc_f32_01 = svmla_f32_x(svptrue_b32(), acc_f32_01, svcvt_f32_s32_x(svptrue_b32(), row01), svmul_lane_f32(sv_b_d, sv_a_d, 0));
|
||||
acc_f32_23 = svmla_f32_x(svptrue_b32(), acc_f32_23, svcvt_f32_s32_x(svptrue_b32(), row23), svmul_lane_f32(sv_b_d, sv_a_d, 2));
|
||||
a_ptr++;
|
||||
b_ptr++;
|
||||
}
|
||||
|
||||
svbool_t pg4 = svptrue_pat_b32(SV_VL4);
|
||||
svst1_f32(pg4, s + (y+0) * bs + x, acc_f32_01);
|
||||
svst1_f32(pg4, s + (y+1) * bs + x, svext_f32(acc_f32_01, acc_f32_01, 4));
|
||||
svst1_f32(pg4, s + (y+2) * bs + x, acc_f32_23);
|
||||
svst1_f32(pg4, s + (y+3) * bs + x, svext_f32(acc_f32_23, acc_f32_23, 4));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
#endif // SVE compile-time end
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
|
||||
|
||||
|
||||
@@ -1245,6 +1245,12 @@ void ggml_compute_forward_mul_mat(
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
const int32_t hint = ggml_get_op_params_i32(dst, 1);
|
||||
if (hint == GGML_HINT_SRC0_IS_HADAMARD && !params->use_ref) {
|
||||
ggml_compute_forward_fwht(params, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int ith = params->ith;
|
||||
@@ -2959,6 +2965,45 @@ struct ggml_cplan ggml_graph_plan(
|
||||
return cplan;
|
||||
}
|
||||
|
||||
|
||||
// Try to fuse the current node with subsequent nodes for better performance.
|
||||
// Returns the number of nodes skipped by fusion (>=1), or 0 if no fusion was applied.
|
||||
static bool ggml_cpu_disable_fusion = false; // initialized once in ggml_cpu_init(), read-only afterwards
|
||||
|
||||
static int ggml_cpu_try_fuse_ops(
|
||||
const struct ggml_cgraph * cgraph,
|
||||
const int node_n,
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_cplan * cplan) {
|
||||
|
||||
if (ggml_cpu_disable_fusion || cplan->use_ref) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||
|
||||
if (node->op == GGML_OP_RMS_NORM) {
|
||||
// RMS_NORM + MUL fusion
|
||||
const enum ggml_op fuse_ops[] = { GGML_OP_RMS_NORM, GGML_OP_MUL };
|
||||
if (ggml_can_fuse(cgraph, node_n, fuse_ops, 2)) {
|
||||
struct ggml_tensor * mul_node = cgraph->nodes[node_n + 1];
|
||||
const struct ggml_tensor * mul_w = (mul_node->src[0] == node)
|
||||
? mul_node->src[1] : mul_node->src[0];
|
||||
if (node->src[0]->type == GGML_TYPE_F32 &&
|
||||
mul_node->type == GGML_TYPE_F32 &&
|
||||
mul_w->type == GGML_TYPE_F32 &&
|
||||
mul_w->ne[0] == node->ne[0] &&
|
||||
mul_w->nb[0] == sizeof(float)) {
|
||||
|
||||
ggml_compute_forward_rms_norm_mul_fused(params, node, mul_node);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
|
||||
struct ggml_threadpool * tp = state->threadpool;
|
||||
@@ -2995,7 +3040,14 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_compute_forward(¶ms, node);
|
||||
// TODO: move fused-op detection into ggml_graph_plan so fusion decisions are made once at planning time
|
||||
// Try fused ops, fall back to normal compute
|
||||
const int n_fused = ggml_cpu_try_fuse_ops(cgraph, node_n, ¶ms, cplan);
|
||||
if (n_fused > 0) {
|
||||
node_n += n_fused;
|
||||
} else {
|
||||
ggml_compute_forward(¶ms, node);
|
||||
}
|
||||
|
||||
if (state->ith == 0 && cplan->abort_callback &&
|
||||
cplan->abort_callback(cplan->abort_callback_data)) {
|
||||
@@ -3757,6 +3809,11 @@ void ggml_cpu_init(void) {
|
||||
ggml_init_riscv_arch_features();
|
||||
#endif
|
||||
|
||||
{
|
||||
const char * env = getenv("GGML_CPU_DISABLE_FUSION");
|
||||
ggml_cpu_disable_fusion = (env != NULL && atoi(env) == 1);
|
||||
}
|
||||
|
||||
is_first_call = false;
|
||||
}
|
||||
|
||||
|
||||
@@ -195,8 +195,8 @@ static const struct ggml_backend_i ggml_backend_cpu_i = {
|
||||
/* .free = */ ggml_backend_cpu_free,
|
||||
/* .set_tensor_async = */ NULL,
|
||||
/* .get_tensor_async = */ NULL,
|
||||
/* .get_tensor_2d_async = */ NULL,
|
||||
/* .set_tensor_2d_async = */ NULL,
|
||||
/* .get_tensor_2d_async = */ NULL,
|
||||
/* .cpy_tensor_async = */ NULL,
|
||||
/* .synchronize = */ NULL,
|
||||
/* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create,
|
||||
|
||||
@@ -2321,6 +2321,9 @@ class tinyBLAS_Q0_PPC {
|
||||
}
|
||||
|
||||
void matmul(int64_t m, int64_t n) {
|
||||
#if defined(_AIX) || defined(__BIG_ENDIAN__)
|
||||
mnpack(0, m, 0, n);
|
||||
#else
|
||||
const int64_t mc = 64;
|
||||
const int64_t kc = 64;
|
||||
int64_t nc = 64;
|
||||
@@ -2334,7 +2337,6 @@ class tinyBLAS_Q0_PPC {
|
||||
} else {
|
||||
n_aligned = (n / 64) * 64;
|
||||
}
|
||||
|
||||
if (n_aligned > 0) {
|
||||
if (n_aligned % 64 == 0) nc = 64;
|
||||
else if (n_aligned == n) nc = n;
|
||||
@@ -2352,6 +2354,7 @@ class tinyBLAS_Q0_PPC {
|
||||
} else {
|
||||
mnpack(0, m, 0, n);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -3191,12 +3194,16 @@ class tinyBLAS_PPC {
|
||||
}
|
||||
|
||||
void matmul(int64_t m, int64_t n) {
|
||||
#if defined(_AIX) || defined(__BIG_ENDIAN__)
|
||||
mnpack(0, m, 0, n);
|
||||
#else
|
||||
int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
|
||||
if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
|
||||
matmul_tiled(m, n, mc, nc, kc);
|
||||
} else {
|
||||
mnpack(0, m, 0, n);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
@@ -3713,11 +3713,27 @@ void ggml_compute_forward_norm(
|
||||
|
||||
// ggml_compute_forward_group_rms_norm
|
||||
|
||||
// fusion kinds that can be combined with the rms_norm computation in a single pass.
|
||||
// extend this enum when adding new fused variants (e.g. FUSE_ADD, FUSE_MUL_ADD, ...).
|
||||
enum ggml_rms_norm_fuse_op {
|
||||
GGML_RMS_NORM_FUSE_OP_NONE,
|
||||
GGML_RMS_NORM_FUSE_OP_MUL,
|
||||
};
|
||||
|
||||
template <ggml_rms_norm_fuse_op FUSE_OP>
|
||||
static void ggml_compute_forward_rms_norm_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
ggml_tensor * dst_rms_norm,
|
||||
ggml_tensor * dst_fused = nullptr) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src0 = dst_rms_norm->src[0];
|
||||
const ggml_tensor * src1 = nullptr;
|
||||
ggml_tensor * dst = dst_rms_norm;
|
||||
|
||||
if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
|
||||
src1 = (dst_fused->src[0] == dst_rms_norm) ? dst_fused->src[1] : dst_fused->src[0];
|
||||
dst = dst_fused;
|
||||
}
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
@@ -3726,11 +3742,10 @@ static void ggml_compute_forward_rms_norm_f32(
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
memcpy(&eps, dst_rms_norm->op_params, sizeof(float));
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
// TODO: optimize
|
||||
@@ -3740,25 +3755,32 @@ static void ggml_compute_forward_rms_norm_f32(
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
// worth switching to explicit SIMD?
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += (ggml_float)(x[i00] * x[i00]);
|
||||
}
|
||||
|
||||
const float mean = sum/ne00;
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
// for (int i00 = 0; i00 < ne00; i00++) {
|
||||
// y[i00] = x[i00];
|
||||
// }
|
||||
|
||||
const float mean = sum/ne00;
|
||||
const float scale = 1.0f/sqrtf(mean + eps);
|
||||
|
||||
// if you hit this, likely you got an inf somewhere earlier
|
||||
assert(scale > 0.0f);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
|
||||
const int64_t i11 = i01 % ne11;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const float * w = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
|
||||
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
y[i00] = x[i00] * scale * w[i00];
|
||||
}
|
||||
} else {
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3773,7 +3795,31 @@ void ggml_compute_forward_rms_norm(
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_rms_norm_f32(params, dst);
|
||||
ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_NONE>(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fused RMS_NORM + MUL: computes dst = rms_norm(src0) * src1 in a single pass.
|
||||
// This avoids materializing the intermediate rms_norm result in memory.
|
||||
void ggml_compute_forward_rms_norm_mul_fused(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst_rms_norm,
|
||||
ggml_tensor * dst_mul) {
|
||||
|
||||
GGML_ASSERT(dst_mul != nullptr);
|
||||
GGML_ASSERT(dst_mul->src[0] == dst_rms_norm || dst_mul->src[1] == dst_rms_norm);
|
||||
|
||||
const ggml_tensor * src0 = dst_rms_norm->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_MUL>(params, dst_rms_norm, dst_mul);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
@@ -11212,3 +11258,91 @@ void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_fwht_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t n = ne10;
|
||||
GGML_ASSERT((n & (n - 1)) == 0); // must be power of 2
|
||||
|
||||
const int64_t nr = ne11 * ne12 * ne13;
|
||||
const int64_t rows_per_thread = (nr + nth - 1) / nth;
|
||||
const int64_t start_row = ith * rows_per_thread;
|
||||
const int64_t end_row = MIN(start_row + rows_per_thread, nr);
|
||||
|
||||
const float scale = 1.0f / sqrtf((float)n);
|
||||
|
||||
#if defined(GGML_SIMD)
|
||||
const GGML_F32_VEC v_minus_one = GGML_F32_VEC_SET1(-1.0f);
|
||||
#endif
|
||||
|
||||
for (int64_t r = start_row; r < end_row; r++) {
|
||||
const int64_t i13 = r / (ne11 * ne12);
|
||||
const int64_t i12 = (r - i13 * ne11 * ne12) / ne11;
|
||||
const int64_t i11 = r - i13 * ne11 * ne12 - i12 * ne11;
|
||||
|
||||
const float * src_row = (const float *) ((const char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13);
|
||||
float * dst_row = (float *) ((char *) dst->data + i11 * nb1 + i12 * nb2 + i13 * nb3);
|
||||
|
||||
for (int64_t j = 0; j < n; j++) {
|
||||
dst_row[j] = src_row[j] * scale;
|
||||
}
|
||||
|
||||
// Scalar passes
|
||||
#if defined(GGML_SIMD)
|
||||
const int step = GGML_F32_EPR;
|
||||
#else
|
||||
const int step = n;
|
||||
#endif
|
||||
for (int64_t len = 1; len < step && len < n; len <<= 1) {
|
||||
for (int64_t i = 0; i < n; i += 2 * len) {
|
||||
for (int64_t j = 0; j < len; j++) {
|
||||
float u = dst_row[i + j];
|
||||
float v = dst_row[i + len + j];
|
||||
dst_row[i + j] = u + v;
|
||||
dst_row[i + len + j] = u - v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SIMD passes using GGML_F32_VEC_* macros for multi-architecture support
|
||||
#if defined(GGML_SIMD)
|
||||
for (int64_t len = step; len < n; len <<= 1) {
|
||||
for (int64_t i = 0; i < n; i += 2 * len) {
|
||||
for (int64_t j = 0; j < len; j += step) {
|
||||
GGML_F32_VEC u = GGML_F32_VEC_LOAD(dst_row + i + j);
|
||||
GGML_F32_VEC v = GGML_F32_VEC_LOAD(dst_row + i + len + j);
|
||||
|
||||
GGML_F32_VEC_STORE(dst_row + i + j, GGML_F32_VEC_ADD(u, v));
|
||||
GGML_F32_VEC_STORE(dst_row + i + len + j, GGML_F32_VEC_FMA(u, v, v_minus_one));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_fwht(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_fwht_f32(params, dst);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error - fwht is F32 only");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,6 +44,7 @@ void ggml_compute_forward_concat(const struct ggml_compute_params * params, stru
|
||||
void ggml_compute_forward_silu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_rms_norm_mul_fused(const struct ggml_compute_params * params, struct ggml_tensor * dst_rms_norm, struct ggml_tensor * dst_mul);
|
||||
void ggml_compute_forward_rms_norm_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_group_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_l2_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
@@ -111,6 +112,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params *
|
||||
void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_fwht(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -830,6 +830,18 @@ static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
|
||||
#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ uint8_t ggml_cuda_fp32_to_ue4m3(float x) {
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE) // This is used for NVFP4 subblock scale quantizations only
|
||||
if (!(x > 0.0f)) {
|
||||
return 0;
|
||||
}
|
||||
const __nv_fp8_e4m3 xf(x);
|
||||
return xf.__x;
|
||||
#else
|
||||
NO_DEVICE_CODE; // Used only for NVFP4 Scales for Activations, only for Blackwell
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
|
||||
const uint8_t sign_bit = (x < 0.0f) << 3;
|
||||
float ax = fabsf(x) * e;
|
||||
|
||||
@@ -66,6 +66,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
@@ -85,6 +88,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
@@ -118,6 +124,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
||||
@@ -1217,7 +1226,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
float KQ_max_scale[cols_per_thread];
|
||||
#pragma unroll
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
|
||||
const int jc = (threadIdx.y/np)*cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col));
|
||||
const float sink = sinks_f[jc % ncols2];
|
||||
|
||||
const float KQ_max_new = fmaxf(KQ_max[col], sink);
|
||||
@@ -1825,6 +1834,10 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
|
||||
// Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build:
|
||||
extern DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32);
|
||||
extern DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32);
|
||||
|
||||
// For GLM 4.7 Flash
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
|
||||
@@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
|
||||
} break;
|
||||
case 320: {
|
||||
GGML_ASSERT(V->ne[0] == 256);
|
||||
ggml_cuda_flash_attn_ext_tile_case<320, 256>(ctx, dst);
|
||||
} break;
|
||||
case 512: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst);
|
||||
|
||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
@@ -128,6 +130,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)
|
||||
@@ -195,6 +199,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
@@ -264,6 +270,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
|
||||
@@ -1116,7 +1124,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
|
||||
#ifdef GGML_USE_HIP
|
||||
if constexpr (DV <= 128) {
|
||||
if constexpr (DKQ <= 128) {
|
||||
if (Q->ne[1] > 32/ncols2) {
|
||||
constexpr int cols_per_block = 64;
|
||||
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
@@ -1130,7 +1138,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
|
||||
#endif // GGML_USE_HIP
|
||||
|
||||
#ifndef GGML_USE_HIP
|
||||
if constexpr (DV <= 256)
|
||||
if constexpr (DKQ <= 256)
|
||||
#endif // GGML_USE_HIP
|
||||
{
|
||||
if (Q->ne[1] > 16/ncols2) {
|
||||
@@ -1144,14 +1152,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
|
||||
}
|
||||
}
|
||||
|
||||
if (Q->ne[1] > 8/ncols2) {
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
|
||||
return;
|
||||
if constexpr (ncols2 <= 16) {
|
||||
if (Q->ne[1] > 8/ncols2) {
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (ncols2 <= 8) {
|
||||
@@ -1210,6 +1220,25 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
if constexpr (DKQ == 320) {
|
||||
// This branch is only used for Mistral Small 4 which has a GQA ratio of 32.
|
||||
// On AMD, simply use that GQA ratio with 32 columns / block since we always have enough SRAM.
|
||||
// On NVIDIA however, the tile kernel is only used for GPUs that can't use the mma kernel (Pascal and older).
|
||||
// Therefore, use a GQA ratio of 16 with 16 columns / block to stay below 48 kiB of SRAM / block.
|
||||
#ifdef GGML_USE_HIP
|
||||
if (use_gqa_opt && gqa_ratio % 32 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 32, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
#else
|
||||
if (use_gqa_opt && gqa_ratio % 16 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32");
|
||||
}
|
||||
|
||||
if constexpr (DKQ == 576) {
|
||||
if (use_gqa_opt && gqa_ratio % 16 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
@@ -1221,7 +1250,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (DKQ <= 512) {
|
||||
if constexpr (DKQ <= 512 && DKQ != 320) {
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
@@ -1275,5 +1304,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
|
||||
extern DECL_FATTN_TILE_CASE(112, 112);
|
||||
extern DECL_FATTN_TILE_CASE(128, 128);
|
||||
extern DECL_FATTN_TILE_CASE(256, 256);
|
||||
extern DECL_FATTN_TILE_CASE(320, 256);
|
||||
extern DECL_FATTN_TILE_CASE(512, 512);
|
||||
extern DECL_FATTN_TILE_CASE(576, 512);
|
||||
|
||||
@@ -143,6 +143,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
GGML_ASSERT(V->ne[0] == 256);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 320:
|
||||
// For Mistral Small 4, go straight to the ncols1 switch (ncols2=32-only build).
|
||||
GGML_ASSERT(V->ne[0] == 256);
|
||||
{
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f;
|
||||
GGML_ASSERT(use_gqa_opt);
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(gqa_ratio % 32 == 0);
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<320, 256, 32>(ctx, dst);
|
||||
}
|
||||
break;
|
||||
case 512:
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst);
|
||||
@@ -352,6 +368,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 320:
|
||||
if (V->ne[0] != 256 || !gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (gqa_ratio % 32 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 512:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
|
||||
@@ -6,17 +6,18 @@ template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static __global__ void k_get_rows(
|
||||
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
|
||||
/*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
|
||||
/*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/
|
||||
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
|
||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||
|
||||
for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
|
||||
for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) {
|
||||
for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
|
||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||
const int i10 = blockIdx.x;
|
||||
const int i11 = z / ne12; // TODO fastdiv
|
||||
const int i12 = z % ne12;
|
||||
const uint2 dm = fast_div_modulo((uint32_t)z, ne12_fdv);
|
||||
const int i11 = dm.x;
|
||||
const int i12 = dm.y;
|
||||
|
||||
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
||||
|
||||
@@ -42,17 +43,18 @@ template<typename src0_t, typename dst_t>
|
||||
static __global__ void k_get_rows_float(
|
||||
const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
|
||||
/*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
|
||||
/*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/
|
||||
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
|
||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||
|
||||
for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
|
||||
for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) {
|
||||
for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
|
||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||
const int i10 = blockIdx.x;
|
||||
const int i11 = z / ne12; // TODO fastdiv
|
||||
const int i12 = z % ne12;
|
||||
const uint2 dm = fast_div_modulo((uint32_t)z, ne12_fdv);
|
||||
const int i11 = dm.x;
|
||||
const int i12 = dm.y;
|
||||
|
||||
if (i00 >= ne00) {
|
||||
return;
|
||||
@@ -115,10 +117,14 @@ static void get_rows_cuda_q(
|
||||
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
|
||||
GGML_ASSERT(ne12 > 0);
|
||||
GGML_ASSERT(ne11 <= std::numeric_limits<uint32_t>::max() / ne12);
|
||||
const uint3 ne12_fdv = init_fastdiv_values(ne12);
|
||||
|
||||
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, /*ne01, ne02, ne03,*/
|
||||
/*ne10,*/ ne11, ne12, /*ne13,*/
|
||||
/*ne10,*/ ne11, ne12_fdv, /*ne13,*/
|
||||
/* s0,*/ s1, s2, s3,
|
||||
/* nb00,*/ nb01, nb02, nb03,
|
||||
s10, s11, s12/*, s13*/);
|
||||
@@ -146,10 +152,14 @@ static void get_rows_cuda_float(
|
||||
const size_t s12 = nb12 / sizeof(int32_t);
|
||||
// const size_t s13 = nb13 / sizeof(int32_t);
|
||||
|
||||
GGML_ASSERT(ne12 > 0);
|
||||
GGML_ASSERT(ne11 <= std::numeric_limits<uint32_t>::max() / ne12);
|
||||
const uint3 ne12_fdv = init_fastdiv_values(ne12);
|
||||
|
||||
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, /*ne01, ne02, ne03,*/
|
||||
/*ne10,*/ ne11, ne12, /*ne13,*/
|
||||
/*ne10,*/ ne11, ne12_fdv, /*ne13,*/
|
||||
/* s0,*/ s1, s2, s3,
|
||||
/* nb00,*/ nb01, nb02, nb03,
|
||||
s10, s11, s12/*, s13*/);
|
||||
|
||||
@@ -3556,6 +3556,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
|
||||
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
|
||||
const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * silu = cgraph->nodes[node_idx+1];
|
||||
if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
@@ -3564,6 +3567,31 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_ADD
|
||||
&& ops.begin()[2] == GGML_OP_UNARY && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
|
||||
const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * add = cgraph->nodes[node_idx+1];
|
||||
const ggml_tensor * silu = cgraph->nodes[node_idx+2];
|
||||
if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ssm_conv->type != GGML_TYPE_F32 || add->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// ADD must consume ssm_conv's output and broadcast a 1-D channel-wise bias.
|
||||
const ggml_tensor * bias = (add->src[0] == ssm_conv) ? add->src[1] : add->src[0];
|
||||
if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) {
|
||||
return false;
|
||||
}
|
||||
if (ggml_nelements(bias) != ssm_conv->ne[0] || bias->ne[0] != ssm_conv->ne[0]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
|
||||
&& unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
|
||||
const ggml_tensor * unary = cgraph->nodes[node_idx];
|
||||
@@ -3640,6 +3668,362 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
|
||||
return false;
|
||||
}
|
||||
|
||||
// try and fuse nodes and return the number of nodes to skip
|
||||
static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, int i) {
|
||||
|
||||
static bool disable_fusion = getenv("GGML_CUDA_DISABLE_FUSION") != nullptr && std::atoi(getenv("GGML_CUDA_DISABLE_FUSION"));
|
||||
if (disable_fusion) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
//topk-moe
|
||||
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
|
||||
cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
|
||||
ggml_cuda_topk_moe_args args;
|
||||
const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
|
||||
std::vector<ggml_op> ops;
|
||||
|
||||
if (can_fuse) {
|
||||
const ggml_tensor * logits = node->src[0];
|
||||
ggml_tensor * weights = nullptr;
|
||||
ggml_tensor * ids = nullptr;
|
||||
const ggml_tensor * bias = nullptr;
|
||||
const ggml_tensor * clamp = nullptr;
|
||||
const ggml_tensor * scale = nullptr;
|
||||
|
||||
if (!args.delayed_softmax) {
|
||||
ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
|
||||
int out_nodes[2]; // nodes which can't be elided
|
||||
|
||||
if (args.prob_bias) {
|
||||
bias = cgraph->nodes[i + 2]->src[1];
|
||||
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, GGML_OP_VIEW,
|
||||
GGML_OP_GET_ROWS });
|
||||
out_nodes[0] = i + 4;
|
||||
ids = cgraph->nodes[i + 4];
|
||||
} else {
|
||||
ops.insert(ops.end(),
|
||||
{ gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS });
|
||||
out_nodes[0] = i + 3;
|
||||
ids = cgraph->nodes[i + 3];
|
||||
}
|
||||
|
||||
if (args.norm) {
|
||||
ops.insert(ops.end(),
|
||||
{ GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, GGML_OP_RESHAPE });
|
||||
clamp = cgraph->nodes[i + ops.size() - 3];
|
||||
}
|
||||
if (args.scale) {
|
||||
ops.insert(ops.end(), { GGML_OP_SCALE });
|
||||
scale = cgraph->nodes[i + ops.size() - 1];
|
||||
}
|
||||
|
||||
weights = cgraph->nodes[i + ops.size() - 1];
|
||||
out_nodes[1] = i + ops.size() - 1;
|
||||
|
||||
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
|
||||
ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
|
||||
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) {
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
|
||||
return ops.size() - 1;
|
||||
}
|
||||
} else if (!args.norm && !args.prob_bias) {
|
||||
//special case gpt-oss, no norm, no bias.
|
||||
ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
|
||||
weights = cgraph->nodes[i + 5];
|
||||
ids = cgraph->nodes[i + 1];
|
||||
const ggml_tensor * softmax = cgraph->nodes[i + 4];
|
||||
|
||||
int out_nodes[2] = { i + 1, i + 5 };
|
||||
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
|
||||
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
|
||||
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) {
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
|
||||
return ops.size() - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//RoPE + view + set-rows
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
|
||||
ggml_tensor * rope = cgraph->nodes[i];
|
||||
ggml_tensor * set_rows = cgraph->nodes[i + 2];
|
||||
|
||||
ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
|
||||
return 2;
|
||||
}
|
||||
|
||||
// multi-(add or mul)
|
||||
if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
|
||||
int n_fuse = 0;
|
||||
ggml_op ops[8];
|
||||
std::fill(ops, ops + 8, node->op);
|
||||
|
||||
for (; n_fuse <= 6; ++n_fuse) {
|
||||
if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
|
||||
break;
|
||||
}
|
||||
if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
|
||||
break;
|
||||
}
|
||||
if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
n_fuse++;
|
||||
|
||||
if (n_fuse > 1) {
|
||||
ggml_tensor fused_node;
|
||||
memcpy(&fused_node, node, sizeof(ggml_tensor));
|
||||
for (int j = 0; j < n_fuse - 1; ++j) {
|
||||
fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
|
||||
}
|
||||
fused_node.data = cgraph->nodes[i + n_fuse - 1]->data;
|
||||
if (node->op == GGML_OP_ADD) {
|
||||
ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse);
|
||||
} else {
|
||||
ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse);
|
||||
}
|
||||
return n_fuse - 1;
|
||||
}
|
||||
}
|
||||
|
||||
bool fused_mul_mat_vec = false;
|
||||
int fused_node_count = 0;
|
||||
|
||||
// gate + glu + up
|
||||
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
|
||||
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
|
||||
ggml_tensor * glu = cgraph->nodes[i + 4];
|
||||
ggml_tensor * gate_bias_n = glu->src[0];
|
||||
ggml_tensor * up_bias_n = glu->src[1];
|
||||
|
||||
//we don't assume the order for {gate, up}. Instead infer it from the bias tensor
|
||||
ggml_tensor * gate_n = nullptr;
|
||||
ggml_tensor * up_n = nullptr;
|
||||
|
||||
if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
|
||||
gate_n = cgraph->nodes[i];
|
||||
up_n = cgraph->nodes[i + 2];
|
||||
} else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
|
||||
gate_n = cgraph->nodes[i + 2];
|
||||
up_n = cgraph->nodes[i];
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
|
||||
if (op_bias == GGML_OP_ADD) {
|
||||
if (bias_node->src[0] == mul_node) {
|
||||
return bias_node->src[1];
|
||||
}
|
||||
if (bias_node->src[1] == mul_node) {
|
||||
return bias_node->src[0];
|
||||
}
|
||||
return (ggml_tensor *) nullptr;
|
||||
}
|
||||
GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
|
||||
GGML_ASSERT(bias_node->src[0] == mul_node);
|
||||
return bias_node->src[1];
|
||||
};
|
||||
|
||||
ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
|
||||
ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
|
||||
|
||||
if (!up_bias_tensor || !gate_bias_tensor) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// we don't support repeating adds
|
||||
if (bias_op == GGML_OP_ADD && (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
|
||||
!ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const ggml_tensor * src0 = up_n->src[0];
|
||||
const ggml_tensor * src1 = up_n->src[1];
|
||||
const ggml_tensor * ids = up_n->src[2];
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate_n->src[0];
|
||||
fusion_data.x_bias = up_bias_tensor;
|
||||
fusion_data.gate_bias = gate_bias_tensor;
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 5;
|
||||
break;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate_n->src[0];
|
||||
fusion_data.x_bias = up_bias_tensor;
|
||||
fusion_data.gate_bias = gate_bias_tensor;
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 5;
|
||||
break;
|
||||
}
|
||||
} else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
|
||||
ggml_tensor * glu = cgraph->nodes[i + 2];
|
||||
ggml_tensor * gate = glu->src[0];
|
||||
ggml_tensor * up = glu->src[1];
|
||||
|
||||
bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) ||
|
||||
(gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
|
||||
|
||||
if (!ok) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const ggml_tensor * src0 = up->src[0];
|
||||
const ggml_tensor * src1 = up->src[1];
|
||||
const ggml_tensor * ids = up->src[2];
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate->src[0];
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 3;
|
||||
break;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate->src[0];
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 3;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (fused_mul_mat_vec) {
|
||||
return fused_node_count - 1;
|
||||
}
|
||||
|
||||
fused_mul_mat_vec = false;
|
||||
fused_node_count = 0;
|
||||
|
||||
// gate + add + glu + up + add
|
||||
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
|
||||
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
|
||||
|
||||
if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_tensor * mm_node = cgraph->nodes[i];
|
||||
ggml_tensor * bias_node = cgraph->nodes[i + 1];
|
||||
|
||||
ggml_tensor * bias_tensor = nullptr;
|
||||
if (bias_op == GGML_OP_ADD) {
|
||||
if (bias_node->src[0] == mm_node) {
|
||||
bias_tensor = bias_node->src[1];
|
||||
} else if (bias_node->src[1] == mm_node) {
|
||||
bias_tensor = bias_node->src[0];
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (bias_node->src[0] != mm_node) {
|
||||
continue;
|
||||
}
|
||||
bias_tensor = bias_node->src[1];
|
||||
}
|
||||
|
||||
const ggml_tensor * src0 = mm_node->src[0];
|
||||
const ggml_tensor * src1 = mm_node->src[1];
|
||||
const ggml_tensor * ids = mm_node->src[2];
|
||||
|
||||
if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.x_bias = bias_tensor;
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
|
||||
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 2;
|
||||
break;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
|
||||
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 2;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (fused_mul_mat_vec) {
|
||||
return fused_node_count - 1;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD }, {})) {
|
||||
ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]);
|
||||
return 2;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
|
||||
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i + 1]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
|
||||
ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]);
|
||||
return 2;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
|
||||
ggml_cuda_op_ssm_conv(*cuda_ctx, node, /*bias_add_node=*/ nullptr, cgraph->nodes[i + 1]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
|
||||
ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
|
||||
ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
|
||||
ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i + 1]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) {
|
||||
ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i + 1]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
|
||||
ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i + 2], node);
|
||||
return 2;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
|
||||
bool graph_evaluated_or_captured = false;
|
||||
|
||||
@@ -3786,355 +4170,11 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
continue;
|
||||
}
|
||||
|
||||
// start of fusion operations
|
||||
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
|
||||
if (!disable_fusion) {
|
||||
ggml_cuda_topk_moe_args args;
|
||||
int nodes_to_skip = ggml_cuda_try_fuse(cuda_ctx, cgraph, i);
|
||||
|
||||
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
|
||||
cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
|
||||
const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
|
||||
|
||||
std::vector<ggml_op> ops;
|
||||
|
||||
if (can_fuse) {
|
||||
const ggml_tensor * logits = node->src[0];
|
||||
ggml_tensor * weights = nullptr;
|
||||
ggml_tensor * ids = nullptr;
|
||||
const ggml_tensor * bias = nullptr;
|
||||
const ggml_tensor * clamp = nullptr;
|
||||
const ggml_tensor * scale = nullptr;
|
||||
|
||||
if (!args.delayed_softmax) {
|
||||
ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
|
||||
int out_nodes[2]; // nodes which can't be elided
|
||||
|
||||
if (args.prob_bias) {
|
||||
bias = cgraph->nodes[i + 2]->src[1];
|
||||
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS });
|
||||
out_nodes[0] = i + 4;
|
||||
ids = cgraph->nodes[i + 4];
|
||||
} else {
|
||||
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
|
||||
GGML_OP_GET_ROWS });
|
||||
out_nodes[0] = i + 3;
|
||||
ids = cgraph->nodes[i + 3];
|
||||
}
|
||||
|
||||
if (args.norm) {
|
||||
ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
|
||||
GGML_OP_DIV, GGML_OP_RESHAPE });
|
||||
clamp = cgraph->nodes[i + ops.size() - 3];
|
||||
}
|
||||
if (args.scale) {
|
||||
ops.insert(ops.end(), { GGML_OP_SCALE });
|
||||
scale = cgraph->nodes[i + ops.size() - 1];
|
||||
}
|
||||
|
||||
weights = cgraph->nodes[i + ops.size() - 1];
|
||||
out_nodes[1] = i + ops.size() - 1;
|
||||
|
||||
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
|
||||
ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
|
||||
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) {
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
|
||||
i += ops.size() - 1;
|
||||
continue;
|
||||
}
|
||||
} else if (!args.norm && !args.prob_bias) {
|
||||
//special case gpt-oss, no norm, no bias.
|
||||
ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
|
||||
GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
|
||||
weights = cgraph->nodes[i + 5];
|
||||
ids = cgraph->nodes[i + 1];
|
||||
const ggml_tensor * softmax = cgraph->nodes[i + 4];
|
||||
|
||||
int out_nodes[2] = { i + 1, i + 5 };
|
||||
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
|
||||
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
|
||||
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) {
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
|
||||
i += ops.size() - 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
|
||||
ggml_tensor * rope = cgraph->nodes[i];
|
||||
ggml_tensor * set_rows = cgraph->nodes[i + 2];
|
||||
|
||||
ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
|
||||
int n_fuse = 0;
|
||||
ggml_op ops[8];
|
||||
std::fill(ops, ops + 8, node->op);
|
||||
|
||||
for (; n_fuse <= 6; ++n_fuse){
|
||||
if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
|
||||
break;
|
||||
}
|
||||
if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
|
||||
break;
|
||||
}
|
||||
if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
n_fuse++;
|
||||
|
||||
if (n_fuse > 1) {
|
||||
ggml_tensor fused_node;
|
||||
memcpy(&fused_node, node, sizeof(ggml_tensor));
|
||||
for (int j = 0; j < n_fuse - 1; ++j) {
|
||||
fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
|
||||
}
|
||||
fused_node.data = cgraph->nodes[i + n_fuse - 1]->data;
|
||||
if (node->op == GGML_OP_ADD) {
|
||||
ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse);
|
||||
} else {
|
||||
ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse);
|
||||
}
|
||||
i += n_fuse - 1;
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
bool fused_mul_mat_vec = false;
|
||||
int fused_node_count = 0;
|
||||
|
||||
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
|
||||
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
|
||||
ggml_tensor * glu = cgraph->nodes[i + 4];
|
||||
ggml_tensor * gate_bias_n = glu->src[0];
|
||||
ggml_tensor * up_bias_n = glu->src[1];
|
||||
|
||||
//we don't assume the order for {gate, up}. Instead infer it from the bias tensor
|
||||
ggml_tensor * gate_n = nullptr;
|
||||
ggml_tensor * up_n = nullptr;
|
||||
|
||||
if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
|
||||
gate_n = cgraph->nodes[i];
|
||||
up_n = cgraph->nodes[i + 2];
|
||||
} else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
|
||||
gate_n = cgraph->nodes[i + 2];
|
||||
up_n = cgraph->nodes[i];
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
|
||||
if (op_bias == GGML_OP_ADD) {
|
||||
if (bias_node->src[0] == mul_node) {
|
||||
return bias_node->src[1];
|
||||
}
|
||||
if (bias_node->src[1] == mul_node) {
|
||||
return bias_node->src[0];
|
||||
}
|
||||
return (ggml_tensor *) nullptr;
|
||||
}
|
||||
GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
|
||||
GGML_ASSERT(bias_node->src[0] == mul_node);
|
||||
return bias_node->src[1];
|
||||
};
|
||||
|
||||
ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
|
||||
ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
|
||||
|
||||
if (!up_bias_tensor || !gate_bias_tensor) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// we don't support repeating adds
|
||||
if (bias_op == GGML_OP_ADD &&
|
||||
(!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
|
||||
!ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const ggml_tensor * src0 = up_n->src[0];
|
||||
const ggml_tensor * src1 = up_n->src[1];
|
||||
const ggml_tensor * ids = up_n->src[2];
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate_n->src[0];
|
||||
fusion_data.x_bias = up_bias_tensor;
|
||||
fusion_data.gate_bias = gate_bias_tensor;
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 5;
|
||||
break;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate_n->src[0];
|
||||
fusion_data.x_bias = up_bias_tensor;
|
||||
fusion_data.gate_bias = gate_bias_tensor;
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 5;
|
||||
break;
|
||||
}
|
||||
} else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
|
||||
ggml_tensor * glu = cgraph->nodes[i + 2];
|
||||
ggml_tensor * gate = glu->src[0];
|
||||
ggml_tensor * up = glu->src[1];
|
||||
|
||||
bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
|
||||
|| (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
|
||||
|
||||
if (!ok) continue;
|
||||
|
||||
const ggml_tensor * src0 = up->src[0];
|
||||
const ggml_tensor * src1 = up->src[1];
|
||||
const ggml_tensor * ids = up->src[2];
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate->src[0];
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 3;
|
||||
break;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate->src[0];
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 3;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (fused_mul_mat_vec) {
|
||||
i += fused_node_count - 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
fused_mul_mat_vec = false;
|
||||
fused_node_count = 0;
|
||||
|
||||
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
|
||||
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
|
||||
|
||||
if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_tensor * mm_node = cgraph->nodes[i];
|
||||
ggml_tensor * bias_node = cgraph->nodes[i + 1];
|
||||
|
||||
ggml_tensor * bias_tensor = nullptr;
|
||||
if (bias_op == GGML_OP_ADD) {
|
||||
if (bias_node->src[0] == mm_node) {
|
||||
bias_tensor = bias_node->src[1];
|
||||
} else if (bias_node->src[1] == mm_node) {
|
||||
bias_tensor = bias_node->src[0];
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (bias_node->src[0] != mm_node) {
|
||||
continue;
|
||||
}
|
||||
bias_tensor = bias_node->src[1];
|
||||
}
|
||||
|
||||
const ggml_tensor * src0 = mm_node->src[0];
|
||||
const ggml_tensor * src1 = mm_node->src[1];
|
||||
const ggml_tensor * ids = mm_node->src[2];
|
||||
|
||||
if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.x_bias = bias_tensor;
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
|
||||
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 2;
|
||||
break;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
|
||||
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 2;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (fused_mul_mat_vec) {
|
||||
i += fused_node_count - 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
|
||||
ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
|
||||
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
|
||||
ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]);
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
|
||||
ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
|
||||
ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
|
||||
ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]);
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) {
|
||||
ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i+1]);
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
|
||||
i += 2;
|
||||
ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
|
||||
continue;
|
||||
}
|
||||
if (nodes_to_skip != 0) {
|
||||
i += nodes_to_skip;
|
||||
continue;
|
||||
}
|
||||
#ifndef NDEBUG
|
||||
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
|
||||
@@ -4548,8 +4588,8 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
|
||||
/* .free = */ ggml_backend_cuda_free,
|
||||
/* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
|
||||
/* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
|
||||
/* .get_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async,
|
||||
/* .set_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async,
|
||||
/* .set_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async,
|
||||
/* .get_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async,
|
||||
/* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async,
|
||||
/* .synchronize = */ ggml_backend_cuda_synchronize,
|
||||
/* .graph_plan_create = */ NULL,
|
||||
@@ -5391,8 +5431,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
|
||||
char pci_bus_id[16] = {};
|
||||
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
|
||||
char pci_bus_id[32] = {};
|
||||
CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i));
|
||||
dev_ctx->pci_bus_id = pci_bus_id;
|
||||
dev_ctx->op_offload_min_batch_size = min_batch_size;
|
||||
|
||||
|
||||
@@ -1015,25 +1015,35 @@ namespace ggml_cuda_mma {
|
||||
#endif // AMD_MFMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
|
||||
const tile<16, 8, int> & A,
|
||||
const tile<8, 8, int> & B,
|
||||
uint32_t a_scale,
|
||||
uint32_t b_scale) {
|
||||
template <ggml_type type>
|
||||
static __device__ __forceinline__ void mma_block_scaled_fp4(tile<16, 8, float> & D,
|
||||
const tile<16, 8, int> & A,
|
||||
const tile<8, 8, int> & B,
|
||||
uint32_t a_scale,
|
||||
uint32_t b_scale) {
|
||||
#ifdef BLACKWELL_MMA_AVAILABLE
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
float * Dxi = (float *) D.x;
|
||||
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
|
||||
"%10, {0, 0}, %11, {0, 0};"
|
||||
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
|
||||
if constexpr (type == GGML_TYPE_MXFP4) {
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
|
||||
"%10, {0, 0}, %11, {0, 0};"
|
||||
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
|
||||
} else {
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
|
||||
"%10, {0, 0}, %11, {0, 0};"
|
||||
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
|
||||
#endif // BLACKWELL_MMA_AVAILABLE
|
||||
#endif // BLACKWELL_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
|
||||
@@ -122,7 +122,7 @@ void ggml_cuda_mul_mat_q(
|
||||
|| GGML_CUDA_CC_IS_CDNA(cc);
|
||||
|
||||
// TODO: tighter pool buffer size vs q8 path
|
||||
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
|
||||
const bool use_native_fp4 = blackwell_mma_available(cc) && (src0->type == GGML_TYPE_MXFP4 || src0->type == GGML_TYPE_NVFP4);
|
||||
|
||||
if (!ids) {
|
||||
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
|
||||
@@ -133,9 +133,9 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s11 = src1->nb[1] / ts_src1;
|
||||
const int64_t s12 = src1->nb[2] / ts_src1;
|
||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||
if (use_native_mxfp4) {
|
||||
if (use_native_fp4) {
|
||||
static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
|
||||
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
|
||||
quantize_mmq_fp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
|
||||
ne11, ne12, ne13, stream);
|
||||
|
||||
} else {
|
||||
@@ -146,10 +146,8 @@ void ggml_cuda_mul_mat_q(
|
||||
}
|
||||
|
||||
// Stride depends on quantization format
|
||||
const int64_t s12 = use_native_mxfp4 ?
|
||||
ne11 * ne10_padded * sizeof(block_fp4_mmq) /
|
||||
(8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
|
||||
:
|
||||
const int64_t s12 = use_native_fp4 ?
|
||||
ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) : // block_fp4_mmq holds 256 values
|
||||
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
|
||||
const int64_t s13 = ne12*s12;
|
||||
|
||||
@@ -198,8 +196,8 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s12 = src1->nb[2] / ts_src1;
|
||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||
|
||||
if (use_native_mxfp4) {
|
||||
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
|
||||
if (use_native_fp4) {
|
||||
quantize_mmq_fp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
|
||||
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
|
||||
} else {
|
||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
|
||||
@@ -208,8 +206,9 @@ void ggml_cuda_mul_mat_q(
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
|
||||
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
|
||||
static_assert(QK_K == 8 * QK_MXFP4, "QK_K needs to be 8 * QK_MXFP4");
|
||||
const int64_t s12 = use_native_fp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) :
|
||||
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
|
||||
const int64_t s13 = ne12*s12;
|
||||
|
||||
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
|
||||
|
||||
@@ -10,9 +10,9 @@
|
||||
using namespace ggml_cuda_mma;
|
||||
|
||||
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
|
||||
#define MMQ_ITER_K 256
|
||||
#define MMQ_ITER_K_MXFP4_FP4 512
|
||||
#define MMQ_NWARPS 8
|
||||
#define MMQ_ITER_K 256
|
||||
#define MMQ_ITER_K_FP4 512
|
||||
#define MMQ_NWARPS 8
|
||||
|
||||
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
|
||||
typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
|
||||
@@ -46,9 +46,12 @@ struct block_q8_1_mmq {
|
||||
int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
|
||||
};
|
||||
|
||||
// this struct is used for fp4 data types (currently only used for Blackwell)
|
||||
// mxfp4 has block size 32, each int32 of d4 contains 2 e8m0 scales in the lower 16 bits
|
||||
// nvfp4 has block size 16, each int32 of d4 contains 4 ue4m3 scales
|
||||
struct block_fp4_mmq {
|
||||
uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
|
||||
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
|
||||
uint32_t d4[4];
|
||||
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte)
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
|
||||
@@ -143,10 +146,11 @@ static int get_mmq_y_host(const int cc) {
|
||||
|
||||
static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
|
||||
#else
|
||||
return MMQ_ITER_K;
|
||||
if (type == GGML_TYPE_NVFP4 || type == GGML_TYPE_MXFP4) {
|
||||
return MMQ_ITER_K_FP4;
|
||||
}
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
return MMQ_ITER_K;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_mmq_y_device() {
|
||||
@@ -213,8 +217,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
||||
}
|
||||
|
||||
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4
|
||||
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4
|
||||
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 and NVFP4 Blackwell
|
||||
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 Generic
|
||||
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
|
||||
@@ -240,7 +244,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
||||
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
// tile sizes are the same for Q8_1 and FP4 for blackwell
|
||||
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_FP4;
|
||||
#else
|
||||
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
|
||||
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
|
||||
@@ -934,6 +942,128 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef BLACKWELL_MMA_AVAILABLE
|
||||
template <int mmq_y, bool need_check>
|
||||
static __device__ __forceinline__ void load_tiles_nvfp4_nvfp4(const char * __restrict__ x,
|
||||
int * __restrict__ x_tile,
|
||||
const int kbx0,
|
||||
const int i_max,
|
||||
const int stride) {
|
||||
constexpr int nwarps = mmq_get_nwarps_device();
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int iter_k = get_iter_k(GGML_TYPE_NVFP4);
|
||||
constexpr int threads_per_row = iter_k / QK_NVFP4; // each thread processes 1 block
|
||||
constexpr int rows_per_warp = warp_size / threads_per_row;
|
||||
|
||||
uint32_t * x_u32 = (uint32_t *) x_tile;
|
||||
|
||||
const int txi = threadIdx.x;
|
||||
const int kbx = txi % threads_per_row;
|
||||
const int row_in_warp = txi / threads_per_row;
|
||||
|
||||
const block_nvfp4 * bxi_base = (const block_nvfp4 *) x + kbx0 + kbx;
|
||||
uint32_t * x_u32_scale = x_u32 + 64 + kbx;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
|
||||
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
|
||||
|
||||
if constexpr (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
|
||||
const block_nvfp4 * bxi = bxi_base + i * stride;
|
||||
const int row_base = i * MMQ_MMA_TILE_X_K_FP4;
|
||||
const int q_base = row_base + 8 * kbx;
|
||||
|
||||
const uint32_t * src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
|
||||
|
||||
#pragma unroll
|
||||
for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
|
||||
x_u32[q_base + 2 * sub + 0] = src_qs[2 * sub + 0];
|
||||
x_u32[q_base + 2 * sub + 1] = src_qs[2 * sub + 1];
|
||||
}
|
||||
|
||||
x_u32_scale[row_base] = get_int_b4(bxi->d, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Shared MMA kernel for MXFP4 and NVFP4 on Blackwell.
|
||||
// Both quantizations encode values as e2m1 (FP4) and produce one uint32 scale per
|
||||
// m16n8k64 MMA call; only the PTX kind (scale_vec::2X ue8m0 vs scale_vec::4X ue4m3)
|
||||
// and the per-type stride constant differ.
|
||||
template <int mmq_x, int mmq_y, ggml_type type>
|
||||
static __device__ __forceinline__ void vec_dot_fp4_fp4_mma(const int * __restrict__ x,
|
||||
const int * __restrict__ y,
|
||||
float * __restrict__ sum,
|
||||
const int k00) {
|
||||
static_assert(type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4,
|
||||
"vec_dot_fp4_fp4_mma: type must be MXFP4 or NVFP4");
|
||||
|
||||
typedef tile<16, 8, int> tile_A;
|
||||
typedef tile<8, 8, int> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
|
||||
constexpr int stride = MMQ_MMA_TILE_X_K_FP4;
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
constexpr int rows_per_warp = 2 * granularity;
|
||||
constexpr int ntx = rows_per_warp / tile_C::I;
|
||||
constexpr int nfrags = MMQ_TILE_NE_K / tile_A::J;
|
||||
|
||||
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_K);
|
||||
|
||||
const int * x_qs = (const int *) x;
|
||||
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const uint32_t * y_sc = (const uint32_t *) y;
|
||||
|
||||
// 2 threads per quad supply the packed scale register to the block_scale MMA,
|
||||
// see https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
|
||||
const int tidx_A = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
|
||||
const int tidx_B = threadIdx.x / 4;
|
||||
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
||||
|
||||
tile_A A[ntx][nfrags];
|
||||
uint32_t scaleA[ntx][nfrags];
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
#pragma unroll
|
||||
for (int frag = 0; frag < nfrags; ++frag) {
|
||||
const int k0 = k00 + frag * tile_A::J;
|
||||
load_ldmatrix(A[n][frag], x_qs + (i0 + n * tile_A::I) * stride + k0, stride);
|
||||
scaleA[n][frag] = x_sc[(i0 + n * tile_A::I + tidx_A) * stride + k0 / tile_A::J];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
|
||||
tile_B B[nfrags];
|
||||
uint32_t scaleB[nfrags];
|
||||
|
||||
#pragma unroll
|
||||
for (int frag = 0; frag < nfrags; ++frag) {
|
||||
const int k0 = frag * tile_B::J;
|
||||
load_generic(B[frag], y_qs + j0 * MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
|
||||
scaleB[frag] = y_sc[(j0 + tidx_B) * MMQ_TILE_Y_K + frag];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
#pragma unroll
|
||||
for (int frag = 0; frag < nfrags; ++frag) {
|
||||
tile_C C = {};
|
||||
mma_block_scaled_fp4<type>(C, A[n][frag], B[frag], scaleA[n][frag], scaleB[frag]);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // BLACKWELL_MMA_AVAILABLE
|
||||
|
||||
|
||||
template <int mmq_y, bool need_check>
|
||||
static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
|
||||
@@ -1163,77 +1293,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
||||
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y>
|
||||
static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
|
||||
const int * __restrict__ y,
|
||||
float * __restrict__ sum,
|
||||
const int k00) {
|
||||
typedef tile<16, 8, int> tile_A;
|
||||
typedef tile<8, 8, int> tile_B;
|
||||
typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
|
||||
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
constexpr int rows_per_warp = 2 * granularity;
|
||||
constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
|
||||
|
||||
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
|
||||
|
||||
// Match layout from load_tiles_mxfp4_fp4
|
||||
const int * x_qs = (const int *) x;
|
||||
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const uint32_t * y_sc = (const uint32_t *) y;
|
||||
|
||||
// tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
|
||||
tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
|
||||
uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
|
||||
|
||||
// Block scale
|
||||
// Each thread has to point to a 4 byte scale value
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
|
||||
|
||||
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
|
||||
const int k0 = k00 + k01;
|
||||
|
||||
load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
|
||||
MMQ_MMA_TILE_X_K_FP4);
|
||||
|
||||
// based on block-scaling document, 2 threads in each quad need to supply to the scale value
|
||||
const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
|
||||
scaleA[n][k01 / (2 * QI_MXFP4)] =
|
||||
*(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
|
||||
tile_B B;
|
||||
uint32_t scaleB; // 2xN scales
|
||||
|
||||
load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
|
||||
|
||||
scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
tile_C C;
|
||||
|
||||
mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y>
|
||||
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
|
||||
@@ -3259,7 +3318,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
|
||||
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
|
||||
#ifdef BLACKWELL_MMA_AVAILABLE
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_MXFP4>;
|
||||
#else
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
||||
@@ -3270,8 +3329,13 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
|
||||
template <int mmq_x, int mmq_y, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
|
||||
static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ;
|
||||
#ifdef BLACKWELL_MMA_AVAILABLE
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4_nvfp4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_NVFP4>;
|
||||
#else
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
|
||||
#endif // BLACKWELL_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
|
||||
};
|
||||
|
||||
@@ -3406,7 +3470,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
||||
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
// FP4 tile stores 8 blocks
|
||||
constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
|
||||
constexpr int ne_block = (type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4) ? QK_K : 4 * QK8_1;
|
||||
#else
|
||||
constexpr int ne_block = 4 * QK8_1;
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
|
||||
@@ -115,6 +115,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(gg
|
||||
case GGML_TYPE_IQ4_NL: return 6;
|
||||
case GGML_TYPE_IQ4_XS: return 5;
|
||||
case GGML_TYPE_MXFP4: return 4;
|
||||
case GGML_TYPE_NVFP4: return 4;
|
||||
case GGML_TYPE_Q2_K: return 4;
|
||||
case GGML_TYPE_Q3_K: return 4;
|
||||
case GGML_TYPE_Q4_0: return 6;
|
||||
@@ -135,6 +136,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggm
|
||||
case GGML_TYPE_IQ3_S: return 6;
|
||||
case GGML_TYPE_IQ3_XXS: return 7;
|
||||
case GGML_TYPE_MXFP4: return 7;
|
||||
case GGML_TYPE_NVFP4: return 8;
|
||||
case GGML_TYPE_Q2_K: return 7;
|
||||
case GGML_TYPE_Q3_K: return 5;
|
||||
default: return MMVQ_MAX_BATCH_SIZE;
|
||||
@@ -221,6 +223,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type
|
||||
case GGML_TYPE_IQ4_NL: return 7;
|
||||
case GGML_TYPE_IQ4_XS: return 5;
|
||||
case GGML_TYPE_MXFP4: return 5;
|
||||
case GGML_TYPE_NVFP4: return 5;
|
||||
case GGML_TYPE_Q3_K: return 4;
|
||||
case GGML_TYPE_Q4_0: return 7;
|
||||
case GGML_TYPE_Q4_1: return 7;
|
||||
|
||||
@@ -70,6 +70,102 @@ __device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
|
||||
return static_cast<uint8_t>(biased);
|
||||
}
|
||||
|
||||
|
||||
static __global__ void quantize_mmq_nvfp4(
|
||||
const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
|
||||
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2) {
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
|
||||
const int64_t i0_base = ((int64_t) blockDim.x * blockIdx.y + threadIdx.x) * QK_NVFP4_SUB;
|
||||
if (i0_base >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i1 = blockIdx.x;
|
||||
const int64_t i2 = blockIdx.z % ne2;
|
||||
const int64_t i3 = blockIdx.z / ne2;
|
||||
const int64_t i01 = ids ? ids[i1] : i1;
|
||||
const int64_t k_block = i0_base / QK_K;
|
||||
const int64_t blocks_per_col = (ne0 + QK_K - 1) / QK_K;
|
||||
if (k_block >= blocks_per_col) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t ib = blockIdx.z * ((int64_t) blocks_per_col * ne1) + k_block * ne1 + blockIdx.x;
|
||||
block_fp4_mmq * y = (block_fp4_mmq *) vy;
|
||||
block_fp4_mmq * yb = y + ib;
|
||||
|
||||
const int sub = (i0_base % QK_K) / QK_NVFP4_SUB;
|
||||
|
||||
float vals_raw[QK_NVFP4_SUB];
|
||||
float amax_raw = 0.0f;
|
||||
const int64_t base_idx = i3 * s03 + i2 * s02 + i01 * s01;
|
||||
#pragma unroll
|
||||
for (int k = 0; k < QK_NVFP4_SUB; k++) {
|
||||
const int64_t i00 = i0_base + k;
|
||||
if (i00 < ne00) {
|
||||
const float v = x[base_idx + i00];
|
||||
vals_raw[k] = v;
|
||||
amax_raw = fmaxf(amax_raw, fabsf(v));
|
||||
} else {
|
||||
vals_raw[k] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int test_offsets[5] = { 0, -1, 1, -2, 2};
|
||||
const int first_fp8_code = (int) ggml_cuda_fp32_to_ue4m3(amax_raw / 6.0f);
|
||||
|
||||
float best_err = FLT_MAX;
|
||||
uint8_t fp8_code = 0;
|
||||
float subblock_scale = 0.0f;
|
||||
|
||||
#pragma unroll // Check +/- 2 to find best code to reduce NVFP4 activation loss. Negligible overhead on Blackwell.
|
||||
for (int i = 0; i < 5; i++) {
|
||||
const int test_code = first_fp8_code + test_offsets[i];
|
||||
if (test_code < 0 || test_code > 0x7e) {
|
||||
continue;
|
||||
}
|
||||
const uint8_t code = (uint8_t) test_code;
|
||||
const float test_scale = ggml_cuda_ue4m3_to_fp32(code);
|
||||
const float test_inv_scale = test_scale > 0.0f ? 0.5f / test_scale : 0.0f;
|
||||
float cur_err = 0.0f;
|
||||
#pragma unroll
|
||||
for (int k = 0; k < QK_NVFP4_SUB; ++k) {
|
||||
const float v = vals_raw[k];
|
||||
const uint8_t q = ggml_cuda_float_to_fp4_e2m1(v, test_inv_scale);
|
||||
const float err_diff = fabsf(v) - fabsf(kvalues_mxfp4[q & 0x7]) * test_scale;
|
||||
cur_err = fmaf(err_diff, err_diff, cur_err);
|
||||
}
|
||||
|
||||
if (cur_err < best_err) {
|
||||
best_err = cur_err;
|
||||
fp8_code = test_code;
|
||||
subblock_scale = test_scale;
|
||||
}
|
||||
}
|
||||
|
||||
const float inv_scale = subblock_scale > 0.0f ? 0.5f / subblock_scale : 0.0f;
|
||||
uint32_t q0 = 0;
|
||||
uint32_t q1 = 0;
|
||||
#pragma unroll // this is faster than the previous __nv_fp4x4_e2m1
|
||||
for (int k = 0; k < QK_NVFP4_SUB / 4; ++k) {
|
||||
q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 0], inv_scale) << (8 * k);
|
||||
q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 8], inv_scale) << (8 * k + 4);
|
||||
q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 4], inv_scale) << (8 * k);
|
||||
q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 12], inv_scale) << (8 * k + 4);
|
||||
}
|
||||
|
||||
uint32_t * yqs = reinterpret_cast<uint32_t *>(yb->qs);
|
||||
yqs[2 * sub + 0] = q0;
|
||||
yqs[2 * sub + 1] = q1;
|
||||
reinterpret_cast<uint8_t *>(yb->d4)[sub] = fp8_code;
|
||||
#else
|
||||
NO_DEVICE_CODE; // This is for Blackwell NVFP4 activations only.
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
|
||||
}
|
||||
|
||||
// quantize values in the format mxfp4 is stored which is interleaved nibbles
|
||||
// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31
|
||||
static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
|
||||
@@ -316,28 +412,32 @@ void quantize_mmq_q8_1_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_mmq_mxfp4_cuda(const float * x,
|
||||
const int32_t * ids,
|
||||
void * vy,
|
||||
[[maybe_unused]] const ggml_type type_src0,
|
||||
const int64_t ne00,
|
||||
const int64_t s01,
|
||||
const int64_t s02,
|
||||
const int64_t s03,
|
||||
const int64_t ne0,
|
||||
const int64_t ne1,
|
||||
const int64_t ne2,
|
||||
const int64_t ne3,
|
||||
cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
|
||||
void quantize_mmq_fp4_cuda(
|
||||
const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
|
||||
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
|
||||
GGML_ASSERT(type_src0 == GGML_TYPE_MXFP4 || type_src0 == GGML_TYPE_NVFP4);
|
||||
GGML_ASSERT(ne0 > 0);
|
||||
|
||||
constexpr int nwarps = 8;
|
||||
constexpr int vals_per_warp = 2 * QK_MXFP4;
|
||||
constexpr int vals_per_block = nwarps * vals_per_warp;
|
||||
if (type_src0 == GGML_TYPE_NVFP4) {
|
||||
GGML_ASSERT(ne00 % QK_NVFP4 == 0);
|
||||
constexpr int nvfp4_block_size = 128;
|
||||
const int64_t block_num_y = (ne0 + QK_NVFP4_SUB * nvfp4_block_size - 1) / (QK_NVFP4_SUB * nvfp4_block_size);
|
||||
const dim3 block_size(nvfp4_block_size, 1, 1);
|
||||
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
|
||||
quantize_mmq_nvfp4<<<num_blocks, block_size, 0, stream>>>(
|
||||
x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
|
||||
} else {
|
||||
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
|
||||
|
||||
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
|
||||
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
|
||||
const dim3 block_size(WARP_SIZE, nwarps, 1);
|
||||
constexpr int nwarps = 8;
|
||||
constexpr int vals_per_warp = 2 * QK_MXFP4;
|
||||
constexpr int vals_per_block = nwarps * vals_per_warp;
|
||||
|
||||
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
|
||||
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
|
||||
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
|
||||
const dim3 block_size(WARP_SIZE, nwarps, 1);
|
||||
|
||||
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ void quantize_mmq_q8_1_cuda(
|
||||
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
|
||||
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
|
||||
|
||||
void quantize_mmq_mxfp4_cuda(const float * x,
|
||||
void quantize_mmq_fp4_cuda(const float * x,
|
||||
const int32_t * ids,
|
||||
void * vy,
|
||||
ggml_type type_src0,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
template <bool apply_silu, size_t split_d_inner, size_t d_conv>
|
||||
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
|
||||
const float * __restrict__ bias,
|
||||
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
|
||||
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
|
||||
const int64_t n_t) {
|
||||
@@ -27,6 +28,8 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
|
||||
w[j] = w_block[tid * stride_w + j];
|
||||
}
|
||||
|
||||
float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f;
|
||||
|
||||
for (int64_t i = 0; i < n_t; i++) {
|
||||
float sumf = 0.0f;
|
||||
|
||||
@@ -42,12 +45,14 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
|
||||
for (size_t j = 0; j < d_conv; j++) {
|
||||
sumf += x[(i + j) % d_conv] * w[j];
|
||||
}
|
||||
sumf += b;
|
||||
y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t>
|
||||
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
|
||||
const float * __restrict__ bias,
|
||||
const int src0_nb0, const int src0_nb1, const int src0_nb2,
|
||||
const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
|
||||
const int dst_nb1, const int dst_nb2, const int64_t n_t) {
|
||||
@@ -97,6 +102,8 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
|
||||
w[j] = w_block[tid * stride_w + j];
|
||||
}
|
||||
|
||||
float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f;
|
||||
|
||||
// Compute from shared memory
|
||||
for (int64_t i = 0; i < local_n_t; i++) {
|
||||
float sumf = 0.0f;
|
||||
@@ -104,12 +111,13 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
|
||||
for (size_t j = 0; j < d_conv; j++) {
|
||||
sumf += smem[tid * n_cols + i + j] * w[j];
|
||||
}
|
||||
sumf += b;
|
||||
y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool apply_silu>
|
||||
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
|
||||
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const float * bias, const int src0_nb0, const int src0_nb1,
|
||||
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
|
||||
const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
|
||||
const int64_t n_s, cudaStream_t stream) {
|
||||
@@ -120,14 +128,14 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
|
||||
constexpr int kNC = decltype(NC)::value;
|
||||
if (n_t <= 32) {
|
||||
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
|
||||
ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
|
||||
ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
|
||||
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
} else {
|
||||
const int64_t split_n_t = 32;
|
||||
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
|
||||
const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float);
|
||||
ssm_conv_long_token_f32<apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>(
|
||||
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -140,11 +148,18 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) {
|
||||
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node, ggml_tensor * silu_dst) {
|
||||
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
|
||||
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
|
||||
const bool fuse_bias = bias_add_node != nullptr;
|
||||
const bool fuse_silu = silu_dst != nullptr;
|
||||
|
||||
// bias always comes with silu.
|
||||
GGML_ASSERT(!fuse_bias || fuse_silu);
|
||||
|
||||
// The bias (when fused) is the non-conv operand of the ADD node.
|
||||
const struct ggml_tensor * bias = fuse_bias ? (bias_add_node->src[0] == dst ? bias_add_node->src[1] : bias_add_node->src[0]) : nullptr;
|
||||
|
||||
// When fusing, write to silu_dst (the node downstream references).
|
||||
const struct ggml_tensor * out = fuse_silu ? silu_dst : dst;
|
||||
|
||||
@@ -160,16 +175,23 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
const float * bias_d = fuse_bias ? (const float *) bias->data : nullptr;
|
||||
float * dst_d = (float *) out->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(out->type == GGML_TYPE_F32);
|
||||
if (fuse_bias) {
|
||||
GGML_ASSERT(bias->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(bias));
|
||||
GGML_ASSERT(ggml_nelements(bias) == nr);
|
||||
}
|
||||
|
||||
if (fuse_silu) {
|
||||
ssm_conv_f32_cuda<true>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
|
||||
ssm_conv_f32_cuda<true>(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
|
||||
out->nb[2], nc, nr, n_t, n_s, stream);
|
||||
} else {
|
||||
ssm_conv_f32_cuda<false>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
|
||||
ssm_conv_f32_cuda<false>(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
|
||||
out->nb[2], nc, nr, n_t, n_s, stream);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr);
|
||||
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node = nullptr, ggml_tensor * silu_dst = nullptr);
|
||||
|
||||
@@ -2,4 +2,5 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
|
||||
|
||||
@@ -2,4 +2,5 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(320, 256);
|
||||
@@ -3,7 +3,7 @@
|
||||
from glob import glob
|
||||
import os
|
||||
|
||||
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576]
|
||||
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576]
|
||||
|
||||
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
|
||||
|
||||
@@ -62,7 +62,7 @@ for filename in glob("*.cu"):
|
||||
os.remove(filename)
|
||||
|
||||
for head_size_kq in HEAD_SIZES_KQ:
|
||||
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
||||
head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
|
||||
with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))
|
||||
|
||||
@@ -84,13 +84,16 @@ for ncols in [8, 16, 32, 64]:
|
||||
continue
|
||||
if head_size_kq == 72:
|
||||
continue
|
||||
if head_size_kq == 512 and ncols2 not in (4, 8):
|
||||
# Skip compilation of unused ncols2 values for niche head sizes:
|
||||
if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4
|
||||
continue
|
||||
if head_size_kq != 576 and ncols2 in (16, 32):
|
||||
if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4
|
||||
continue
|
||||
if head_size_kq == 576 and ncols2 not in (4, 16, 32):
|
||||
if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash
|
||||
continue
|
||||
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
||||
if head_size_kq not in (320, 576) and ncols2 in (16, 32):
|
||||
continue
|
||||
head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
|
||||
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
|
||||
|
||||
for type in TYPES_MMQ:
|
||||
|
||||
1
ggml/src/ggml-cuda/vendors/hip.h
vendored
1
ggml/src/ggml-cuda/vendors/hip.h
vendored
@@ -55,6 +55,7 @@
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceGetAttribute hipDeviceGetAttribute
|
||||
#define cudaDeviceGetPCIBusId hipDeviceGetPCIBusId
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
|
||||
1
ggml/src/ggml-cuda/vendors/musa.h
vendored
1
ggml/src/ggml-cuda/vendors/musa.h
vendored
@@ -39,6 +39,7 @@
|
||||
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
|
||||
#define cudaDeviceGetPCIBusId musaDeviceGetPCIBusId
|
||||
#define cudaDeviceProp musaDeviceProp
|
||||
#define cudaDeviceSynchronize musaDeviceSynchronize
|
||||
#define cudaError_t musaError_t
|
||||
|
||||
@@ -22,7 +22,8 @@ message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for
|
||||
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||
include(ExternalProject)
|
||||
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
|
||||
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
|
||||
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
|
||||
|
||||
|
||||
@@ -48,14 +48,16 @@ using intvec = std::vector<int>;
|
||||
using uintvec = std::vector<unsigned int>;
|
||||
using u32vec = std::vector<uint32_t>;
|
||||
|
||||
static size_t opt_ndev = 1;
|
||||
static size_t opt_nhvx = 0; // use all
|
||||
static int opt_arch = 0; // autodetect
|
||||
static int opt_etm = 0;
|
||||
static int opt_verbose = 0;
|
||||
static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu)
|
||||
static int opt_hostbuf = 1; // hostbuf ON by default
|
||||
static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only
|
||||
static int opt_arch = 0; // autodetect
|
||||
static size_t opt_ndev = 1;
|
||||
static size_t opt_nhvx = 0; // use all
|
||||
static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only
|
||||
static size_t opt_vmem = HTP_OP_MAX_VMEM_DEFAULT; // max available va space for buffer mappings
|
||||
static size_t opt_mbuf = 1ul * 1024 * 1024 * 1024; // max buffer size
|
||||
static int opt_etm = 0;
|
||||
static int opt_verbose = 0;
|
||||
static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu)
|
||||
static int opt_hostbuf = 1; // hostbuf ON by default
|
||||
|
||||
// Default PMU events, if profiling with PMU (mode=2) is enabled
|
||||
// See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html
|
||||
@@ -66,6 +68,7 @@ static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C }
|
||||
static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE;
|
||||
static int opt_opbatch = 1024; // max number of ops in a batch
|
||||
static int opt_opqueue = 16; // max number of pending batches
|
||||
|
||||
static std::regex* opt_opfilter = NULL; // regex of ops to not claim
|
||||
|
||||
#define HEX_VERBOSE(...) \
|
||||
@@ -110,7 +113,7 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct
|
||||
if (!opt_verbose) return;
|
||||
|
||||
op_desc desc(op);
|
||||
GGML_LOG_DEBUG("ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\n", sess_name.c_str(),
|
||||
GGML_LOG_DEBUG("ggml-hex: %s supports-op %s: %s : %s : %s : %s : %s : %s\n", sess_name.c_str(),
|
||||
ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no");
|
||||
}
|
||||
|
||||
@@ -118,8 +121,6 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t
|
||||
uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) {
|
||||
if (!opt_profile) return;
|
||||
|
||||
op_desc desc(op);
|
||||
|
||||
char pmu_str[256] = "";
|
||||
if (opt_profile > 1) {
|
||||
static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters");
|
||||
@@ -127,6 +128,7 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t
|
||||
pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]);
|
||||
}
|
||||
|
||||
op_desc desc(op);
|
||||
GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(),
|
||||
ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, op_usec, op_cycles, pmu_str);
|
||||
}
|
||||
@@ -191,33 +193,30 @@ struct ggml_hexagon_shared_buffer {
|
||||
bool mapped;
|
||||
bool pinned;
|
||||
|
||||
void mmap(bool pinned = false) {
|
||||
int err = fastrpc_mmap(sess->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD_DELAYED);
|
||||
void mmap() {
|
||||
fastrpc_map_flags flags = this->pinned ? FASTRPC_MAP_FD : FASTRPC_MAP_FD_DELAYED;
|
||||
|
||||
int err = fastrpc_mmap(sess->domain_id, this->fd, (void *) this->base, 0, this->size, flags);
|
||||
if (err != 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: %s buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(),
|
||||
sess->domain_id, this->size, this->fd, (unsigned) err);
|
||||
throw std::runtime_error("ggml-hex: fastrpc_mmap failed (see log for details)");
|
||||
}
|
||||
|
||||
if (pinned) {
|
||||
err = htp_iface_mmap(sess->handle, this->fd, this->size, pinned);
|
||||
if (err != 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: %s buffer pinning failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(),
|
||||
sess->domain_id, this->size, this->fd, (unsigned) err);
|
||||
throw std::runtime_error("ggml-hex: htp_iface_mmap failed (see log for details)");
|
||||
}
|
||||
}
|
||||
|
||||
this->mapped = true;
|
||||
this->pinned = pinned;
|
||||
HEX_VERBOSE("ggml-hex: %s mapped buffer: base %p size %zu fd %d pinned %u\n",
|
||||
sess->c_name(), (void *) this->base, this->size, this->fd, pinned);
|
||||
|
||||
this->mapped = true;
|
||||
}
|
||||
|
||||
void unmap() {
|
||||
if (!this->mapped) return;
|
||||
|
||||
htp_iface_munmap(sess->handle, this->fd);
|
||||
if (!this->pinned) {
|
||||
// HTP might still hold a reference, tell it drop it
|
||||
htp_iface_munmap(sess->handle, this->fd);
|
||||
}
|
||||
|
||||
fastrpc_munmap(sess->domain_id, this->fd, (void *) this->base, this->size);
|
||||
|
||||
HEX_VERBOSE("ggml-hex: %s unmapped buffer: base %p size %zu fd %d\n", sess->c_name(),
|
||||
@@ -227,7 +226,7 @@ struct ggml_hexagon_shared_buffer {
|
||||
this->fd = -1;
|
||||
}
|
||||
|
||||
void alloc(size_t size, bool pinned = false) {
|
||||
void alloc(size_t size) {
|
||||
if (this->base) return;
|
||||
|
||||
this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, size);
|
||||
@@ -245,8 +244,7 @@ struct ggml_hexagon_shared_buffer {
|
||||
|
||||
HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d pinned %d\n", sess->c_name(),
|
||||
(void *) this->base, this->size, this->fd, (int) pinned);
|
||||
|
||||
mmap(pinned);
|
||||
mmap();
|
||||
}
|
||||
|
||||
void free() {
|
||||
@@ -262,15 +260,14 @@ struct ggml_hexagon_shared_buffer {
|
||||
}
|
||||
|
||||
ggml_hexagon_shared_buffer(ggml_hexagon_session * sess, size_t size, bool pinned = false) {
|
||||
size += 4 * 1024; // extra page for padding
|
||||
|
||||
this->sess = sess;
|
||||
this->size = 0;
|
||||
this->base = nullptr;
|
||||
this->fd = -1;
|
||||
this->mapped = false;
|
||||
this->pinned = pinned;
|
||||
|
||||
alloc(size, pinned);
|
||||
alloc(size);
|
||||
}
|
||||
|
||||
~ggml_hexagon_shared_buffer() {
|
||||
@@ -1475,6 +1472,7 @@ static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer(
|
||||
ggml_backend_buffer_type_t buffer_type, size_t size) {
|
||||
auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;
|
||||
try {
|
||||
size += 4 * 1024; // guard page
|
||||
ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size);
|
||||
return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size);
|
||||
} catch (const std::exception & exc) {
|
||||
@@ -1487,6 +1485,7 @@ static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffe
|
||||
ggml_backend_buffer_type_t buffer_type, size_t size) {
|
||||
auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;
|
||||
try {
|
||||
size += 4 * 1024; // guard page
|
||||
ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size);
|
||||
return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size);
|
||||
} catch (const std::exception & exc) {
|
||||
@@ -1505,7 +1504,7 @@ static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffe
|
||||
}
|
||||
|
||||
static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {
|
||||
return 1UL * 1024 * 1024 * 1024; // 1GB per buffer
|
||||
return opt_mbuf; // typically 1GB per buffer
|
||||
GGML_UNUSED(buffer_type);
|
||||
}
|
||||
|
||||
@@ -1573,14 +1572,14 @@ struct ggml_hexagon_opbatch {
|
||||
d_map.clear();
|
||||
}
|
||||
|
||||
ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size) {
|
||||
ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size, size_t max_vmem) {
|
||||
this->sess = sess;
|
||||
|
||||
n_bufs_max = HTP_OP_MAX_BUFS;
|
||||
n_ops_max = batch_size;
|
||||
n_tens_max = n_ops_max + n_ops_max * HTP_OP_MAX_INPUTS;
|
||||
|
||||
b_vmem_max = HTP_OP_MAX_VMEM;
|
||||
b_vmem_max = max_vmem;
|
||||
|
||||
ops.resize(n_ops_max);
|
||||
|
||||
@@ -1592,6 +1591,9 @@ struct ggml_hexagon_opbatch {
|
||||
t_map.reserve(n_tens_max);
|
||||
d_map.reserve(n_tens_max);
|
||||
|
||||
GGML_LOG_INFO("ggml-hex: %s op batching: n-bufs %u n-tensors %u n-ops %u vmem %zu\n",
|
||||
sess->c_name(), n_bufs_max, n_tens_max, n_ops_max, b_vmem_max);
|
||||
|
||||
reset();
|
||||
}
|
||||
|
||||
@@ -1925,6 +1927,8 @@ void ggml_hexagon_session::flush_batch() {
|
||||
// Bump pending flag (cleared in the session::flush once we get the response)
|
||||
this->op_pending++; // atomic inc
|
||||
|
||||
HEX_VERBOSE("ggml-hex: %s queue-opbatch: %p size %u\n", this->c_name(), dbuf.ptr, dbuf.size);
|
||||
|
||||
int err = dspqueue_write(this->queue, 0, 1, &dbuf, sizeof(req), (const uint8_t*) &req, DSPQUEUE_TIMEOUT);
|
||||
if (err != 0) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->c_name(), (unsigned) err);
|
||||
@@ -1944,6 +1948,35 @@ void ggml_hexagon_session::flush(bool all) {
|
||||
flush_pending(all);
|
||||
}
|
||||
|
||||
static size_t ggml_hexagon_measure_max_vmem(ggml_hexagon_session *sess) {
|
||||
// Allocate a bunch pinned buffers till failure.
|
||||
// This is kind of expensive but handy for figuring out exactly how much we can mmap on a specific device.
|
||||
// Typically we're going to allocate all/most of these buffers anyway for the model weights.
|
||||
|
||||
std::vector<ggml_hexagon_shared_buffer *> sbufs;
|
||||
|
||||
const size_t MiB = 1024 * 1024;
|
||||
const size_t GiB = MiB * 1024;
|
||||
|
||||
size_t vmem = 0;
|
||||
size_t step = 256u * MiB;
|
||||
|
||||
try {
|
||||
sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB;
|
||||
sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB;
|
||||
sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB;
|
||||
|
||||
while (1) {
|
||||
sbufs.push_back(new ggml_hexagon_shared_buffer(sess, step, true));
|
||||
vmem += step;
|
||||
}
|
||||
} catch (...) { }
|
||||
|
||||
for (auto b : sbufs) { delete b; }
|
||||
|
||||
return vmem - step; // backoff to account for overhead from internal mappings
|
||||
}
|
||||
|
||||
void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
|
||||
this->valid_session = false;
|
||||
this->valid_handle = false;
|
||||
@@ -1957,7 +1990,7 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
|
||||
|
||||
this->op_pending = 0;
|
||||
|
||||
GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str());
|
||||
GGML_LOG_DEBUG("ggml-hex: %s allocating new session\n", this->name.c_str());
|
||||
|
||||
domain * my_domain = get_domain(this->domain_id);
|
||||
if (my_domain == NULL) {
|
||||
@@ -2033,9 +2066,6 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
|
||||
|
||||
this->valid_handle = true;
|
||||
|
||||
GGML_LOG_INFO("ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\n", this->name.c_str(),
|
||||
this->session_id, this->domain_id, session_uri, (unsigned long) this->handle);
|
||||
|
||||
// Enable FastRPC QoS mode
|
||||
{
|
||||
struct remote_rpc_control_latency l;
|
||||
@@ -2047,6 +2077,9 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
|
||||
}
|
||||
}
|
||||
|
||||
GGML_LOG_INFO("ggml-hex: %s new session : session-id %d domain-id %d uri %s handle 0x%lx\n", this->c_name(),
|
||||
this->session_id, this->domain_id, session_uri, (unsigned long) this->handle);
|
||||
|
||||
const size_t req_q_size = (sizeof(htp_opbatch_req) * opt_opqueue * 2) + 1024;
|
||||
const size_t rsp_q_size = (sizeof(htp_opbatch_rsp) * opt_opqueue * 2) + 1024;
|
||||
|
||||
@@ -2091,13 +2124,19 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
|
||||
}
|
||||
|
||||
// Allocate buffers and state for op batching
|
||||
this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch);
|
||||
this->op_queue = new ggml_hexagon_opqueue(this, opt_opbatch, opt_opqueue);
|
||||
|
||||
// Start processing op batch requests
|
||||
err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx);
|
||||
if (!opt_vmem) {
|
||||
opt_vmem = ggml_hexagon_measure_max_vmem(this);
|
||||
GGML_LOG_INFO("ggml-hex: %s measured max vmem %zu\n", this->c_name(), opt_vmem);
|
||||
}
|
||||
|
||||
this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch, opt_vmem);
|
||||
|
||||
// Start dspqueue/opbatch processing
|
||||
err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx, opt_vmem);
|
||||
if (err != 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err);
|
||||
GGML_LOG_ERROR("ggml-hex: %s failed to start session: 0x%08x\n", this->c_name(), (unsigned) err);
|
||||
throw std::runtime_error("ggml-hex: iface start failed (see log for details)");
|
||||
}
|
||||
this->valid_iface = true;
|
||||
@@ -2108,17 +2147,17 @@ void ggml_hexagon_session::release() noexcept(true) {
|
||||
|
||||
int err;
|
||||
|
||||
delete this->op_batch;
|
||||
delete this->op_queue;
|
||||
|
||||
// Stop the DSP-side service and close the queue
|
||||
if (this->valid_iface) {
|
||||
// Stop dspqueue/opbatch processing
|
||||
err = htp_iface_stop(this->handle);
|
||||
if (err != 0) {
|
||||
GGML_ABORT("ggml-hex: htp_iface_stop failed: 0x%08x\n", (unsigned) err);
|
||||
}
|
||||
}
|
||||
|
||||
delete this->op_batch;
|
||||
delete this->op_queue;
|
||||
|
||||
if (opt_etm) {
|
||||
err = htp_iface_etm(this->handle, 0);
|
||||
if (err != 0) {
|
||||
@@ -2215,8 +2254,7 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->ne[2] != 1 || dst->ne[3] != 1) {
|
||||
// FA during prompt still needs work
|
||||
if (dst->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2382,8 +2420,8 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-contigiuos tensors
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
|
||||
// TODO: add support for non-contiguous elements within a row
|
||||
if (!ggml_is_contiguous_rows(src0) || !ggml_is_contiguous_rows(dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2997,8 +3035,8 @@ static struct ggml_backend_i hexagon_backend_i = {
|
||||
/* .free = */ ggml_backend_hexagon_free,
|
||||
/* .set_tensor_async = */ NULL,
|
||||
/* .get_tensor_async = */ NULL,
|
||||
/* .get_tensor_2d_async = */ NULL,
|
||||
/* .set_tensor_2d_async = */ NULL,
|
||||
/* .get_tensor_2d_async = */ NULL,
|
||||
/* .cpy_tensor_async = */ NULL,
|
||||
/* .synchronize = */ ggml_backend_hexagon_synchronize,
|
||||
/* .graph_plan_create = */ NULL,
|
||||
@@ -3380,21 +3418,6 @@ struct ggml_hexagon_registry {
|
||||
ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
|
||||
GGML_LOG_INFO("ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev %zu\n", opt_ndev);
|
||||
|
||||
if (!opt_arch) {
|
||||
int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch);
|
||||
if (err != 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err);
|
||||
opt_arch = 73;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
if (opt_arch < 75) {
|
||||
opt_ndev = 1;
|
||||
GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch);
|
||||
|
||||
// Create devices / sessions
|
||||
@@ -3480,32 +3503,67 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL,
|
||||
"please update hexagon_type to match ggml_type");
|
||||
|
||||
const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");
|
||||
const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF");
|
||||
const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE");
|
||||
const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH");
|
||||
const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE");
|
||||
const char * str_opfilter= getenv("GGML_HEXAGON_OPFILTER");
|
||||
const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
|
||||
const char * str_etm = getenv("GGML_HEXAGON_ETM");
|
||||
const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
|
||||
const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
|
||||
const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
|
||||
const char * str_arch = getenv("GGML_HEXAGON_ARCH");
|
||||
const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");
|
||||
const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF");
|
||||
const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE");
|
||||
const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH");
|
||||
const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE");
|
||||
const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER");
|
||||
const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
|
||||
const char * str_etm = getenv("GGML_HEXAGON_ETM");
|
||||
const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
|
||||
const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
|
||||
const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
|
||||
const char * str_arch = getenv("GGML_HEXAGON_ARCH");
|
||||
const char * str_vmem = getenv("GGML_HEXAGON_VMEM");
|
||||
const char * str_mbuf = getenv("GGML_HEXAGON_MBUF");
|
||||
|
||||
// Init Arch first since it affects other defaults
|
||||
if (!str_arch) {
|
||||
int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch);
|
||||
if (err != 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err);
|
||||
opt_arch = 73;
|
||||
}
|
||||
} else {
|
||||
if (str_arch[0] == 'v' || str_arch[0] == 'V') {
|
||||
str_arch++;
|
||||
}
|
||||
opt_arch = strtoul(str_arch, NULL, 0);
|
||||
}
|
||||
|
||||
size_t MiB = 1024 * 1024;
|
||||
|
||||
// Update vmem default
|
||||
opt_vmem = opt_arch >= 75 ? HTP_OP_MAX_VMEM_DEFAULT : 3000 * MiB;
|
||||
|
||||
auto RE_ICASE = std::regex_constants::icase;
|
||||
|
||||
opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL;
|
||||
opt_verbose = str_verbose ? atoi(str_verbose) : 0;
|
||||
opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
|
||||
opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage;
|
||||
opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch;
|
||||
opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue;
|
||||
opt_etm = str_etm ? atoi(str_etm) : 0;
|
||||
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
|
||||
opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx;
|
||||
opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
|
||||
opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
|
||||
opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL;
|
||||
opt_verbose = str_verbose ? atoi(str_verbose) : 0;
|
||||
opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
|
||||
opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage;
|
||||
opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch;
|
||||
opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue;
|
||||
opt_profile = str_profile ? atoi(str_profile) : 0;
|
||||
opt_etm = str_etm ? atoi(str_etm) : 0;
|
||||
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
|
||||
opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx;
|
||||
opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
|
||||
opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
|
||||
opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf;
|
||||
opt_vmem = str_vmem ? strtoul(str_vmem, NULL, 0) * MiB : opt_vmem;
|
||||
|
||||
if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
|
||||
opt_ndev = GGML_HEXAGON_MAX_SESSIONS;
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
if (opt_arch < 75) {
|
||||
opt_ndev = 1;
|
||||
GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
if (str_profile) {
|
||||
opt_pmu_evt = [&]() -> std::vector<uint32_t> {
|
||||
@@ -3520,17 +3578,6 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
vec_to_str<uint32_t, 16>(opt_pmu_evt).c_str());
|
||||
}
|
||||
|
||||
if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
|
||||
opt_ndev = GGML_HEXAGON_MAX_SESSIONS;
|
||||
}
|
||||
|
||||
if (str_arch) {
|
||||
if (str_arch[0] == 'v') {
|
||||
str_arch++;
|
||||
}
|
||||
opt_arch = strtoul(str_arch, NULL, 0);
|
||||
}
|
||||
|
||||
reg->context = new ggml_hexagon_registry(reg);
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,11 @@ target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
|
||||
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
|
||||
|
||||
if (GGML_HEXAGON_FA_EXP2_HF)
|
||||
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
|
||||
endif()
|
||||
|
||||
# HMX acceleration: available on v73+ architectures
|
||||
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
|
||||
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
|
||||
@@ -52,11 +57,13 @@ if (_hmx_idx GREATER_EQUAL 0)
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-queue.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
)
|
||||
|
||||
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
|
||||
set_source_files_properties(
|
||||
hmx-matmul-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
PROPERTIES COMPILE_OPTIONS "-mhmx"
|
||||
)
|
||||
|
||||
|
||||
@@ -138,15 +138,15 @@ set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
|
||||
set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
|
||||
|
||||
#Compiler Options
|
||||
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O2")
|
||||
|
||||
set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||
set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
|
||||
set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
|
||||
set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
|
||||
set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O2")
|
||||
|
||||
set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}")
|
||||
set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}")
|
||||
|
||||
@@ -17,13 +17,14 @@
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hmx-ops.h"
|
||||
|
||||
// Must be multiple of 32
|
||||
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
|
||||
|
||||
// This is a bit of a hack because the compiler is strugling to properly inline
|
||||
// the default hvx_vec_f32_to_f16 with output into the local array.
|
||||
static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
|
||||
static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
|
||||
{
|
||||
*(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
|
||||
}
|
||||
@@ -621,6 +622,17 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
// HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV
|
||||
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) {
|
||||
int ret = hmx_flash_attn_ext(octx);
|
||||
if (ret == HTP_STATUS_OK) {
|
||||
return ret;
|
||||
}
|
||||
// VTCM too small or other failure -> fall through to HVX path
|
||||
}
|
||||
#endif
|
||||
|
||||
struct htp_fa_context factx;
|
||||
factx.octx = octx;
|
||||
|
||||
|
||||
@@ -74,6 +74,12 @@ static inline size_t hex_smax(size_t a, size_t b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
static inline void hex_swap_ptr(void ** p1, void ** p2) {
|
||||
void * t = *p1;
|
||||
*p1 = *p2;
|
||||
*p2 = t;
|
||||
}
|
||||
|
||||
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
|
||||
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
|
||||
Q6_l2fetch_AP((void *) p, control);
|
||||
|
||||
1840
ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c
Normal file
1840
ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -61,6 +61,9 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx,
|
||||
int m, int k, int n,
|
||||
int weight_type);
|
||||
|
||||
// HMX flash attention
|
||||
int hmx_flash_attn_ext(struct htp_ops_context * octx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -4,6 +4,9 @@
|
||||
#ifndef HMX_UTILS_H
|
||||
#define HMX_UTILS_H
|
||||
|
||||
#include "hvx-base.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <stddef.h>
|
||||
|
||||
@@ -12,21 +15,188 @@
|
||||
#define HMX_FP16_TILE_N_ELMS 1024
|
||||
#define HMX_FP16_TILE_SIZE 2048
|
||||
|
||||
#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline))
|
||||
|
||||
// Initialise aligned 256-byte area with scale vector + zero padding.
|
||||
static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
|
||||
HVX_Vector *pv = (HVX_Vector *)out_scales;
|
||||
*pv++ = v_scale;
|
||||
*pv = Q6_V_vzero();
|
||||
static inline void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
|
||||
volatile HVX_Vector *pv = (HVX_Vector *) out_scales;
|
||||
pv[0] = v_scale;
|
||||
pv[1] = Q6_V_vzero();
|
||||
}
|
||||
|
||||
// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) ---
|
||||
// --- Shared scatter offsets and interleave helper ---
|
||||
|
||||
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
|
||||
uint8_t *p = *vtcm_ptr;
|
||||
*vtcm_ptr += size;
|
||||
return p;
|
||||
// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile.
|
||||
// word[i] = i*128 maps K-row-pair i to byte offset i*128.
|
||||
// Column offset (n*4) is added at runtime. Entries 0..15 cover one tile (region 2047);
|
||||
// entries 16..31 cover the next adjacent tile (region 4095) — pick region size at the
|
||||
// call site to scatter into one tile (masked) or two contiguous tiles (unmasked).
|
||||
static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = {
|
||||
0 * 128, 1 * 128, 2 * 128, 3 * 128, 4 * 128, 5 * 128, 6 * 128, 7 * 128, 8 * 128, 9 * 128, 10 * 128,
|
||||
11 * 128, 12 * 128, 13 * 128, 14 * 128, 15 * 128, 16 * 128, 17 * 128, 18 * 128, 19 * 128, 20 * 128, 21 * 128,
|
||||
22 * 128, 23 * 128, 24 * 128, 25 * 128, 26 * 128, 27 * 128, 28 * 128, 29 * 128, 30 * 128, 31 * 128,
|
||||
};
|
||||
|
||||
// Scatter row-major FP16 data (in VTCM scratch) into transposed [K][N] tiles.
|
||||
// vtcm_src: [n_cols][src_stride] row-major fp16 (only first k elements per row are used)
|
||||
// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16
|
||||
// Processes rows [start_row, end_row) for multi-thread slicing.
|
||||
// Full range: start_row=0, end_row=n_cols.
|
||||
static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
const __fp16 * restrict vtcm_src,
|
||||
int n_cols,
|
||||
int k,
|
||||
int src_stride,
|
||||
int start_row,
|
||||
int end_row) {
|
||||
assert(k % HMX_FP16_TILE_N_COLS == 0);
|
||||
|
||||
const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
|
||||
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
|
||||
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
|
||||
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
|
||||
// Each hvx_vmemu load brings 64 fp16 = 128 bytes covering 2 adjacent K-tiles.
|
||||
// When n_k_tiles is even, scatter into 2 K-tiles per call (region 4095, no mask)
|
||||
// using the upper half of hmx_transpose_scatter_offsets. Tail one K-tile (when
|
||||
// n_k_tiles is odd) falls back to single-tile masked scatter.
|
||||
const bool pair_scatter = (n_k_tiles & 1) == 0;
|
||||
const size_t pair_region = (size_t) (2 * HMX_FP16_TILE_SIZE - 1);
|
||||
const size_t single_region = (size_t) (HMX_FP16_TILE_SIZE - 1);
|
||||
__builtin_assume(k > 0);
|
||||
__builtin_assume(end_row > start_row);
|
||||
|
||||
if (pair_scatter) {
|
||||
// Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter.
|
||||
const int c_step = 2 * HMX_FP16_TILE_N_COLS;
|
||||
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
|
||||
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
|
||||
const int n_c_iters = k / c_step;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
const int ct = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int local_r = r % HMX_FP16_TILE_N_ROWS;
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
|
||||
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
|
||||
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
|
||||
|
||||
__fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
|
||||
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
|
||||
|
||||
if (p1) {
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0);
|
||||
p0 += c_byte_step;
|
||||
HVX_Vector v1 = hvx_vmemu(p1);
|
||||
p1 += c_byte_step;
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1);
|
||||
tile_base += dst_step;
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0);
|
||||
p0 += c_byte_step;
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
|
||||
tile_base += dst_step;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback: scatter one K-tile per call (region 2047, masked).
|
||||
const int c_step = HMX_FP16_TILE_N_COLS;
|
||||
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
|
||||
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
|
||||
const int n_c_iters = k / c_step;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
const int ct = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int local_r = r % HMX_FP16_TILE_N_ROWS;
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
|
||||
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
|
||||
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
|
||||
|
||||
__fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
|
||||
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
|
||||
|
||||
if (p1) {
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0);
|
||||
p0 += c_byte_step;
|
||||
HVX_Vector v1 = hvx_vmemu(p1);
|
||||
p1 += c_byte_step;
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1);
|
||||
tile_base += dst_step;
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0);
|
||||
p0 += c_byte_step;
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
|
||||
tile_base += dst_step;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Interleave row-major FP16 data into column-major tile format.
|
||||
// Input: [n_rows, head_dim] row-major. Output: tile[dim_tile][row_tile].
|
||||
// Processes rows [start_row, end_row) for multi-thread slicing.
|
||||
// Full range: start_row=0, end_row=n_rows.
|
||||
static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
|
||||
const __fp16 * restrict src,
|
||||
int n_rows,
|
||||
int head_dim,
|
||||
int src_stride,
|
||||
int n_row_tiles,
|
||||
int start_row,
|
||||
int end_row) {
|
||||
__builtin_assume(head_dim > 0);
|
||||
const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows;
|
||||
|
||||
const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride);
|
||||
const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL;
|
||||
|
||||
// Row-pair invariants hoisted out of the c loop.
|
||||
const int r0 = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
|
||||
|
||||
// tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0).
|
||||
// Each c step (+= 64) advances both by 2 dim-tiles worth of fp16.
|
||||
__fp16 * tb0 = tiles_out + (size_t) r0 * HMX_FP16_TILE_N_ELMS;
|
||||
__fp16 * tb1 = tb0 + tile_stride_elms;
|
||||
const size_t tb_step = 2 * tile_stride_elms;
|
||||
|
||||
if (pv_in1) {
|
||||
for (int c = 0; c < head_dim; c += 64) {
|
||||
HVX_Vector v0 = *pv_in0++;
|
||||
HVX_Vector v1 = *pv_in1++;
|
||||
HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2);
|
||||
((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
|
||||
((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp);
|
||||
tb0 += tb_step;
|
||||
tb1 += tb_step;
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int c = 0; c < head_dim; c += 64) {
|
||||
HVX_Vector v0 = *pv_in0++;
|
||||
HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2);
|
||||
((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
|
||||
((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp);
|
||||
tb0 += tb_step;
|
||||
tb1 += tb_step;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // HMX_UTILS_H
|
||||
|
||||
@@ -20,7 +20,7 @@ struct htp_mmap {
|
||||
uint64_t size;
|
||||
uint64_t base;
|
||||
uint32_t fd;
|
||||
uint32_t pinned;
|
||||
uint32_t reserved;
|
||||
};
|
||||
|
||||
// Scratchpad state
|
||||
@@ -77,6 +77,8 @@ struct htp_context {
|
||||
atomic_bool vtcm_valid;
|
||||
atomic_bool vtcm_needs_release;
|
||||
|
||||
uint64_t max_vmem;
|
||||
|
||||
struct htp_ops_context octx;
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
|
||||
@@ -90,15 +90,11 @@ enum htp_op_code {
|
||||
#define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS
|
||||
#define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS
|
||||
|
||||
#define HTP_OP_MAX_BUFS 8
|
||||
#define HTP_OP_MAX_BUFS 16
|
||||
#define HTP_OP_MAX_REQS 256
|
||||
#define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS)
|
||||
|
||||
#if __HVX_ARCH__ < 75
|
||||
#define HTP_OP_MAX_VMEM (3167538380u)
|
||||
#else
|
||||
#define HTP_OP_MAX_VMEM (3221225472u)
|
||||
#endif
|
||||
#define HTP_OP_MAX_VMEM_DEFAULT (3355443200u)
|
||||
|
||||
#define HTP_MMAP_MAX_VMEM (2147483648u)
|
||||
|
||||
|
||||
@@ -11,9 +11,9 @@ struct htp_iface_pmu_conf {
|
||||
};
|
||||
|
||||
interface htp_iface : remote_handle64 {
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx);
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem);
|
||||
AEEResult stop();
|
||||
AEEResult mmap(in uint32 fd, in uint32 size, in uint32 pinned);
|
||||
AEEResult mmap(in uint32 fd, in uint32 size);
|
||||
AEEResult munmap(in uint32 fd);
|
||||
AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu);
|
||||
AEEResult etm(in uint32 enable);
|
||||
|
||||
@@ -77,6 +77,12 @@ static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline _Float16 hvx_vec_get_f16(HVX_Vector v) {
|
||||
_Float16 __attribute__((aligned(128))) x;
|
||||
hvx_vec_store_a(&x, 2, v);
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
|
||||
// abs by clearing the fp16 sign bit
|
||||
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
|
||||
|
||||
@@ -7,7 +7,8 @@
|
||||
|
||||
#include "hvx-base.h"
|
||||
|
||||
#define hvx_splat_loop_body(dst_type, vec_store) \
|
||||
#define hvx_splat_pragma(x) _Pragma(#x)
|
||||
#define hvx_splat_loop_body(dst_type, vec_store, unroll_cnt) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
\
|
||||
@@ -16,7 +17,7 @@
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
hvx_splat_pragma(unroll(unroll_cnt)) \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = src; \
|
||||
} \
|
||||
@@ -25,31 +26,47 @@
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
|
||||
static inline void hvx_splat_a(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a);
|
||||
hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a, 4);
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
|
||||
hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u);
|
||||
static inline void hvx_splat_u(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
|
||||
hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u, 4);
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) {
|
||||
static inline void hvx_splat_f32_a(void * restrict dst, float v, uint32_t n) {
|
||||
hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) {
|
||||
static inline void hvx_splat_f32_u(void * restrict dst, float v, uint32_t n) {
|
||||
hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f16_a(uint8_t * restrict dst, _Float16 v, uint32_t n) {
|
||||
static inline void hvx_splat_f16_a(void * restrict dst, _Float16 v, uint32_t n) {
|
||||
hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f16_u(uint8_t * restrict dst, _Float16 v, uint32_t n) {
|
||||
static inline void hvx_splat_f16_u(void * restrict dst, _Float16 v, uint32_t n) {
|
||||
hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u16_a(void * restrict dst, uint16_t v, uint32_t n) {
|
||||
hvx_splat_a(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u16_u(void * restrict dst, uint16_t v, uint32_t n) {
|
||||
hvx_splat_u(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u8_a(void * restrict dst, uint8_t v, uint32_t n) {
|
||||
hvx_splat_a(dst, Q6_Vb_vsplat_R(v), n, 1);
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u8_u(void * restrict dst, uint8_t v, uint32_t n) {
|
||||
hvx_splat_u(dst, Q6_Vb_vsplat_R(v), n, 1);
|
||||
}
|
||||
|
||||
#define hvx_copy_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
|
||||
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
|
||||
#define EXP_ONE (0x3f800000) // 1.0
|
||||
#define EXP_RANGE_R (0x42B16666) // 88.7
|
||||
#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228
|
||||
#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
|
||||
|
||||
static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
|
||||
@@ -163,7 +163,7 @@ static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict
|
||||
HVX_Vector vec_out = Q6_V_vzero();
|
||||
|
||||
static const float kInf = INFINITY;
|
||||
static const float kMaxExp = 88.7f;
|
||||
static const float kMaxExp = 88.7228f;
|
||||
|
||||
const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
|
||||
|
||||
@@ -210,7 +210,7 @@ AEEResult htp_iface_close(remote_handle64 handle) {
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32 pinned) {
|
||||
AEEResult htp_iface_mmap(remote_handle64 handle, uint32_t fd, uint32_t size) {
|
||||
struct htp_context * ctx = (struct htp_context *) handle;
|
||||
if (!ctx) {
|
||||
return AEE_EBADPARM;
|
||||
@@ -220,7 +220,6 @@ AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32
|
||||
for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) {
|
||||
struct htp_mmap *m = &ctx->mmap[i];
|
||||
if (m->fd == fd) {
|
||||
m->pinned = pinned;
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
}
|
||||
@@ -229,7 +228,7 @@ AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32
|
||||
for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) {
|
||||
struct htp_mmap *m = &ctx->mmap[i];
|
||||
if (!m->size) {
|
||||
FARF(HIGH, "mmap : fd %u size %u pinned %u", fd, size, pinned);
|
||||
FARF(HIGH, "mmap : fd %u size %u", fd, size);
|
||||
#if __HVX_ARCH__ > 73
|
||||
void *va = HAP_mmap2(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0);
|
||||
#else
|
||||
@@ -248,7 +247,6 @@ AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32
|
||||
m->base = (uint64_t) va;
|
||||
m->fd = fd;
|
||||
m->size = size;
|
||||
m->pinned = pinned;
|
||||
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
@@ -275,7 +273,6 @@ AEEResult htp_iface_munmap(remote_handle64 handle, uint32 fd) {
|
||||
m->size = 0;
|
||||
m->base = NULL;
|
||||
m->fd = -1;
|
||||
m->pinned = 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -358,7 +355,7 @@ static void vtcm_free(struct htp_context * ctx) {
|
||||
static void htp_packet_callback(dspqueue_t queue, int error, void * context);
|
||||
static void htp_error_callback(dspqueue_t queue, int error, void * context);
|
||||
|
||||
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx) {
|
||||
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) {
|
||||
struct htp_context * ctx = (struct htp_context *) handle;
|
||||
|
||||
if (!ctx) {
|
||||
@@ -376,12 +373,12 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
||||
htp_error_callback, // Error callback; no errors expected on the DSP
|
||||
(void *) ctx, // Callback context
|
||||
&ctx->queue);
|
||||
|
||||
if (err) {
|
||||
FARF(ERROR, "Queue import failed with 0x%08x", (unsigned) err);
|
||||
return err;
|
||||
}
|
||||
|
||||
ctx->max_vmem = max_vmem;
|
||||
ctx->thread_id = qurt_thread_get_id();
|
||||
ctx->thread_prio = qurt_thread_get_priority(ctx->thread_id);
|
||||
|
||||
@@ -622,8 +619,8 @@ static inline bool reuse_buf(struct htp_context *ctx, uint32_t *m_reuse, struct
|
||||
}
|
||||
|
||||
static inline void drop_mmap(struct htp_context *ctx, struct htp_mmap *m) {
|
||||
if (m->size && !m->pinned) {
|
||||
FARF(HIGH, "unmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned);
|
||||
if (m->size) {
|
||||
FARF(HIGH, "unmap : fd %u base %p size %u", m->fd, (void*) m->base, (uint32_t) m->size);
|
||||
#if __HVX_ARCH__ > 73
|
||||
HAP_munmap2((void *) m->base, m->size);
|
||||
#else
|
||||
@@ -660,9 +657,8 @@ static inline void mmap_buf(struct htp_context *ctx, struct htp_buf_desc *b) {
|
||||
m->base = b->base = (uint64_t) va;
|
||||
m->fd = b->fd;
|
||||
m->size = b->size;
|
||||
m->pinned = 0;
|
||||
|
||||
FARF(HIGH, "mmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned);
|
||||
FARF(HIGH, "mmap : fd %u base %p size %u", m->fd, (void*) m->base, (uint32_t) m->size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -672,8 +668,8 @@ static void prep_op_bufs(struct htp_context *ctx, struct htp_buf_desc *bufs, uin
|
||||
uint32_t m_reuse = 0; // mmap reuse mask (index from ctx->mmap array)
|
||||
uint32_t b_reuse = 0; // buf reuse count
|
||||
|
||||
size_t m_vmem = 0; // mapped vmem
|
||||
size_t e_vmem = 0; // extra vmem
|
||||
uint64_t m_vmem = 0; // mapped vmem
|
||||
uint64_t e_vmem = 0; // extra vmem
|
||||
|
||||
// See what we can reuse
|
||||
for (uint32_t i=0; i < n_bufs; i++) {
|
||||
@@ -687,9 +683,10 @@ static void prep_op_bufs(struct htp_context *ctx, struct htp_buf_desc *bufs, uin
|
||||
// See how much vmem we have mmaped right now
|
||||
for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) { m_vmem += ctx->mmap[i].size; }
|
||||
|
||||
FARF(HIGH, "prep-bufs : pass1 mmap-vmem %zu extra-vmem %zu n-bufs %u b-reuse %u", m_vmem, e_vmem, n_bufs, b_reuse);
|
||||
FARF(HIGH, "prep-bufs : pass1 mmap-vmem %zu extra-vmem %zu max-vmem %zu : n-bufs %u b-reuse %u",
|
||||
(size_t) m_vmem, (size_t) e_vmem, (size_t) ctx->max_vmem, n_bufs, b_reuse);
|
||||
|
||||
if ((m_vmem + e_vmem) > HTP_OP_MAX_VMEM) {
|
||||
if ((m_vmem + e_vmem) > ctx->max_vmem) {
|
||||
// Drop unused mappings
|
||||
for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) {
|
||||
bool used = m_reuse & (1<<i);
|
||||
|
||||
@@ -2991,12 +2991,10 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// M alignment: when M > 32 but not 32-aligned, we split into
|
||||
// HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
|
||||
// When M <= 32 and not 32-aligned, fall back entirely to HVX.
|
||||
// M alignment: Use HMX when M >= 32, the last partial tile (m_total % 32 rows)
|
||||
// is handled by HMX itself; when M < 32 fall back to HVX.
|
||||
const int m_total = (int) src1->ne[1];
|
||||
const int m_tail = m_total % 32;
|
||||
const int m_hmx = m_total - m_tail;
|
||||
const int m_hmx = m_total & ~31; // 0 when M < 32
|
||||
|
||||
if (m_hmx == 0) {
|
||||
return op_matmul_hvx(octx);
|
||||
@@ -3009,7 +3007,6 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
int k = (int) src0->ne[0]; // inner dimension
|
||||
int n = (int) src0->ne[1]; // weight columns
|
||||
|
||||
// --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
|
||||
int ret = -1;
|
||||
|
||||
// Row strides in elements. For compact tensors these equal k; for
|
||||
@@ -3027,7 +3024,7 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
.dst = (float *) dst->data,
|
||||
.activation = (float *) src1->data,
|
||||
.permuted_weight = (const __fp16 *) src0->data,
|
||||
.m = m_hmx,
|
||||
.m = m_total,
|
||||
.k = k,
|
||||
.n = n,
|
||||
.act_stride = act_stride,
|
||||
@@ -3048,12 +3045,12 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_w16a32(octx->ctx,
|
||||
(float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data,
|
||||
m_hmx, k, n, act_stride, wgt_stride);
|
||||
m_total, k, n, act_stride, wgt_stride);
|
||||
}
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx,
|
||||
(float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
|
||||
m_hmx, k, n, (int) src0->type);
|
||||
m_total, k, n, (int) src0->type);
|
||||
}
|
||||
|
||||
if (ret != 0) {
|
||||
@@ -3061,27 +3058,6 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
return op_matmul(octx);
|
||||
}
|
||||
|
||||
// --- Phase 2: HVX on the remaining m_tail rows ---
|
||||
if (m_tail > 0) {
|
||||
// copy of src1 and dst
|
||||
struct htp_tensor src1_tail = *src1;
|
||||
struct htp_tensor dst_tail = *dst;
|
||||
|
||||
src1_tail.ne[1] = m_tail; // only tail rows
|
||||
dst_tail.ne[1] = m_tail; // only tail rows
|
||||
|
||||
// Offset activation and dst pointers past the HMX-processed rows.
|
||||
// Use nb[1] (row stride in bytes) to compute the byte offset.
|
||||
src1_tail.data += (uint32_t) m_hmx * src1->nb[1];
|
||||
dst_tail.data += (uint32_t) m_hmx * dst->nb[1];
|
||||
|
||||
octx->src[1] = &src1_tail;
|
||||
octx->dst = &dst_tail;
|
||||
|
||||
FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data);
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
return 0;
|
||||
#endif // HTP_HAS_HMX
|
||||
}
|
||||
|
||||
@@ -26,8 +26,8 @@ struct htp_unary_context {
|
||||
const uint8_t * data_src0;
|
||||
uint8_t * data_dst;
|
||||
|
||||
size_t src0_row_size;
|
||||
size_t dst_row_size;
|
||||
size_t src0_data_row_size; // actual data bytes per row
|
||||
size_t dst_data_row_size; // actual data bytes per row
|
||||
|
||||
size_t src0_row_size_aligned;
|
||||
size_t dst_row_size_aligned;
|
||||
@@ -41,6 +41,40 @@ struct htp_unary_context {
|
||||
uint32_t nc;
|
||||
};
|
||||
|
||||
// Convert flat row index to DDR byte offset using the tensor's actual strides.
|
||||
// ir = i1 + ne1*(i2 + ne2*i3) => offset = i1*nb1 + i2*nb2 + i3*nb3
|
||||
static inline size_t unary_row_offset(uint32_t ir,
|
||||
uint32_t ne1, uint32_t ne2,
|
||||
size_t nb1, size_t nb2, size_t nb3) {
|
||||
const uint32_t i1 = ir % ne1;
|
||||
const uint32_t i2 = (ir / ne1) % ne2;
|
||||
const uint32_t i3 = ir / (ne1 * ne2);
|
||||
return i1 * nb1 + i2 * nb2 + i3 * nb3;
|
||||
}
|
||||
// Safe DMA block size from row `ir`: clamp to the tighter dim-1 slice
|
||||
// boundary of src and dst so the nb1 stride stays valid for all rows.
|
||||
static inline uint32_t unary_block_size(uint32_t ir,
|
||||
uint32_t end_row,
|
||||
uint32_t block,
|
||||
bool src_contig,
|
||||
bool dst_contig,
|
||||
uint32_t src_ne1,
|
||||
uint32_t dst_ne1) {
|
||||
uint32_t limit = MIN(block, end_row - ir);
|
||||
|
||||
if (!src_contig) {
|
||||
const uint32_t src_slice_end = (ir / src_ne1 + 1) * src_ne1;
|
||||
limit = MIN(limit, src_slice_end - ir);
|
||||
}
|
||||
|
||||
if (!dst_contig) {
|
||||
const uint32_t dst_slice_end = (ir / dst_ne1 + 1) * dst_ne1;
|
||||
limit = MIN(limit, dst_slice_end - ir);
|
||||
}
|
||||
|
||||
return limit;
|
||||
}
|
||||
|
||||
#define htp_unary_preamble \
|
||||
const uint32_t ne00 = src->ne[0]; \
|
||||
const uint32_t ne01 = src->ne[1]; \
|
||||
@@ -276,8 +310,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
int32_t * op_params = octx->op_params;
|
||||
uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;
|
||||
|
||||
const size_t src0_row_size = uctx->src0_row_size;
|
||||
const size_t dst_row_size = uctx->dst_row_size;
|
||||
const size_t src0_data_row_size = uctx->src0_data_row_size;
|
||||
const size_t dst_data_row_size = uctx->dst_data_row_size;
|
||||
|
||||
const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
|
||||
@@ -303,7 +337,16 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
size_t src0_spad_half_size = uctx->src0_spad_half_size;
|
||||
size_t dst_spad_half_size = uctx->dst_spad_half_size;
|
||||
|
||||
const int BLOCK = uctx->block;
|
||||
// Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride
|
||||
// 2D DMA descriptor cannot span. Clamp BLOCK to ne1 (one dim-1 slice) so every
|
||||
// transfer stays within a nb1-uniform region. Skipped for contiguous tensors.
|
||||
const bool src0_contig = (nb02 == (size_t)ne01 * nb01) &&
|
||||
(nb03 == (size_t)ne02 * nb02);
|
||||
const bool dst_contig = (nb2 == (size_t)ne1 * nb1) &&
|
||||
(nb3 == (size_t)ne2 * nb2);
|
||||
const uint32_t src0_max_block = src0_contig ? uctx->block : MIN((uint32_t)uctx->block, ne01);
|
||||
const uint32_t dst_max_block = dst_contig ? uctx->block : MIN((uint32_t)uctx->block, ne1);
|
||||
const uint32_t BLOCK = MIN(src0_max_block, dst_max_block);
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
@@ -312,21 +355,23 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) {
|
||||
const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
||||
|
||||
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_queue_push(dma_queue,
|
||||
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
|
||||
dst_row_size, dst_row_size_aligned, 0);
|
||||
nb1, dst_row_size_aligned, dst_data_row_size, 0);
|
||||
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, block_size);
|
||||
const size_t src0_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03);
|
||||
dma_queue_push(dma_queue,
|
||||
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off),
|
||||
src0_row_size_aligned, nb01, src0_data_row_size, block_size);
|
||||
ir += block_size;
|
||||
}
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ) {
|
||||
const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
||||
|
||||
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
|
||||
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
|
||||
@@ -361,18 +406,25 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
break;
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
|
||||
dst_row_size, dst_row_size_aligned, block_size);
|
||||
const size_t dst_off = unary_row_offset(ir, ne1, ne2, nb1, nb2, nb3);
|
||||
dma_queue_push(dma_queue,
|
||||
dma_make_ptr(data_dst + dst_off, dst_spad),
|
||||
nb1, dst_row_size_aligned, dst_data_row_size, block_size);
|
||||
|
||||
// prefetch N+2 loop iteration if any
|
||||
const uint32_t pref_block = (ir + BLOCK * 2);
|
||||
if (pref_block < src0_end_row) {
|
||||
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, pref_block_size);
|
||||
const uint32_t next_ir = ir + block_size;
|
||||
if (next_ir < src0_end_row) {
|
||||
const uint32_t next_block_size = unary_block_size(next_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
||||
const uint32_t pref_ir = next_ir + next_block_size;
|
||||
if (pref_ir < src0_end_row) {
|
||||
const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
||||
const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03);
|
||||
dma_queue_push(dma_queue,
|
||||
dma_make_ptr(src0_spad, data_src + src0_pref_off),
|
||||
src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size);
|
||||
}
|
||||
}
|
||||
ir += block_size;
|
||||
}
|
||||
|
||||
dma_queue_flush(dma_queue);
|
||||
@@ -426,11 +478,11 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
const size_t src0_data_row_size = src0->ne[0] * sizeof(float);
|
||||
const size_t dst_data_row_size = dst->ne[0] * sizeof(float);
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN);
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
@@ -468,8 +520,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
.data_src0 = (const uint8_t *)src0->data,
|
||||
.data_dst = (uint8_t *)dst->data,
|
||||
|
||||
.src0_row_size = src0_row_size,
|
||||
.dst_row_size = dst_row_size,
|
||||
.src0_data_row_size = src0_data_row_size,
|
||||
.dst_data_row_size = dst_data_row_size,
|
||||
|
||||
.src0_row_size_aligned = src0_row_size_aligned,
|
||||
.dst_row_size_aligned = dst_row_size_aligned,
|
||||
|
||||
16
ggml/src/ggml-hexagon/htp/vtcm-utils.h
Normal file
16
ggml/src/ggml-hexagon/htp/vtcm-utils.h
Normal file
@@ -0,0 +1,16 @@
|
||||
#ifndef VTCM_UTILS_H
|
||||
#define VTCM_UTILS_H
|
||||
|
||||
#include "hex-utils.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <hexagon_types.h>
|
||||
|
||||
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
|
||||
uint8_t *p = *vtcm_ptr;
|
||||
*vtcm_ptr += size;
|
||||
return p;
|
||||
}
|
||||
|
||||
#endif // VTCM_UTILS_H
|
||||
@@ -282,6 +282,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf);
|
||||
void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
|
||||
void ggml_metal_buffer_set_tensor (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
void ggml_metal_buffer_get_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||
bool ggml_metal_buffer_cpy_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||
void ggml_metal_buffer_clear (ggml_metal_buffer_t buf, uint8_t value);
|
||||
|
||||
// finds the Metal buffer that contains the tensor data on the GPU device
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#import "ggml-metal-device.h"
|
||||
|
||||
#import "ggml-impl.h"
|
||||
#import "ggml-backend-impl.h"
|
||||
|
||||
#include <Foundation/Foundation.h>
|
||||
|
||||
@@ -1737,6 +1738,47 @@ void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_ten
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_metal_buffer_cpy_tensor(ggml_metal_buffer_t buf_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
|
||||
ggml_metal_buffer_t buf_src = (ggml_metal_buffer_t)src->buffer->context;
|
||||
|
||||
const size_t size = ggml_nbytes(src);
|
||||
|
||||
// if both buffers are shared, we can use memcpy directly
|
||||
if (buf_dst->is_shared && buf_src->is_shared) {
|
||||
memcpy(dst->data, src->data, size);
|
||||
return true;
|
||||
}
|
||||
|
||||
// for private buffers, we need to use Metal blit commands
|
||||
@autoreleasepool {
|
||||
struct ggml_metal_buffer_id bid_src = ggml_metal_buffer_get_id(buf_src, src);
|
||||
struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf_dst, dst);
|
||||
|
||||
if (bid_src.metal == nil || bid_dst.metal == nil) {
|
||||
return false;
|
||||
}
|
||||
|
||||
id<MTLCommandBuffer> cmd_buf = [buf_dst->dev->mtl_queue commandBufferWithUnretainedReferences];
|
||||
|
||||
{
|
||||
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
||||
|
||||
[encoder copyFromBuffer:bid_src.metal
|
||||
sourceOffset:bid_src.offs
|
||||
toBuffer:bid_dst.metal
|
||||
destinationOffset:bid_dst.offs
|
||||
size:size];
|
||||
|
||||
[encoder endEncoding];
|
||||
}
|
||||
|
||||
[cmd_buf commit];
|
||||
[cmd_buf waitUntilCompleted];
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) {
|
||||
if (buf->is_shared) {
|
||||
memset(buf->all_data, value, buf->all_size);
|
||||
|
||||
@@ -17,6 +17,9 @@
|
||||
// note: can be overridden with GGML_METAL_DEVICES env to simulate virtual devices
|
||||
static int g_devices = 1;
|
||||
|
||||
// forward declaration
|
||||
static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// backend interface
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -68,11 +71,11 @@ static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t bu
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst);
|
||||
if (!ggml_backend_buffer_is_metal(src->buffer)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
return ggml_metal_buffer_cpy_tensor(ctx, src, dst);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
@@ -144,11 +147,11 @@ static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t b
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst);
|
||||
if (!ggml_backend_buffer_is_metal(src->buffer)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
return ggml_metal_buffer_cpy_tensor(ctx, src, dst);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
@@ -166,8 +169,8 @@ static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = {
|
||||
/* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor,
|
||||
/* .get_tensor_2d_async = */ NULL,
|
||||
/* .set_tensor_2d_async = */ NULL,
|
||||
/* .set_tensor_2d = */ NULL,
|
||||
/* .get_tensor_2d = */ NULL,
|
||||
/* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_metal_buffer_private_clear,
|
||||
/* .reset = */ NULL,
|
||||
@@ -567,8 +570,8 @@ static ggml_backend_i ggml_backend_metal_i = {
|
||||
/* .free = */ ggml_backend_metal_free,
|
||||
/* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
|
||||
/* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
|
||||
/* .get_tensor_2d_async = */ NULL,
|
||||
/* .set_tensor_2d_async = */ NULL,
|
||||
/* .get_tensor_2d_async = */ NULL,
|
||||
/* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups
|
||||
/* .synchronize = */ ggml_backend_metal_synchronize,
|
||||
/* .graph_plan_create = */ NULL,
|
||||
|
||||
@@ -66,8 +66,6 @@ set(GGML_OPENCL_KERNELS
|
||||
diag
|
||||
div
|
||||
gelu
|
||||
gemv_noshuffle_general
|
||||
gemv_noshuffle
|
||||
get_rows
|
||||
glu
|
||||
group_norm
|
||||
@@ -75,7 +73,6 @@ set(GGML_OPENCL_KERNELS
|
||||
im2col_f32
|
||||
im2col_f16
|
||||
mean
|
||||
mul_mat_Ab_Bi_8x4
|
||||
mul_mv_f16_f16
|
||||
mul_mv_f16_f32_1row
|
||||
mul_mv_f16_f32_l4
|
||||
@@ -107,6 +104,10 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mv_id_mxfp4_f32_flat
|
||||
gemm_moe_mxfp4_f32
|
||||
gemv_moe_mxfp4_f32
|
||||
gemm_moe_mxfp4_f32_ns
|
||||
gemv_moe_mxfp4_f32_ns
|
||||
moe_reorder_b
|
||||
moe_sort_by_expert
|
||||
mul_mm_f32_f32_l4_lm
|
||||
mul_mm_f16_f32_l4_lm
|
||||
mul_mm_q4_0_f32_l4_lm
|
||||
@@ -116,12 +117,15 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mm_q4_k_f32_l4_lm
|
||||
mul_mm_q5_k_f32_l4_lm
|
||||
mul_mm_q6_k_f32_l4_lm
|
||||
mul_mm_q8_0_f32_8x4
|
||||
gemv_noshuffle_q4_0_f32
|
||||
gemv_noshuffle_q4_0_f32_spec
|
||||
gemm_noshuffle_q4_0_f32
|
||||
gemv_noshuffle_q4_1_f32
|
||||
gemm_noshuffle_q4_1_f32
|
||||
gemv_noshuffle_iq4_nl_f32
|
||||
gemm_noshuffle_iq4_nl_f32
|
||||
gemv_noshuffle_general_q8_0_f32
|
||||
gemv_noshuffle_q8_0_f32
|
||||
gemm_noshuffle_q8_0_f32
|
||||
gemv_noshuffle_q4_k_f32
|
||||
gemm_noshuffle_q4_k_f32
|
||||
gemv_noshuffle_q6_k_f32
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -371,6 +371,93 @@ kernel void kernel_restore_block_mxfp4_trans(
|
||||
b->e = src_e[src_blk_offset];
|
||||
}
|
||||
|
||||
kernel void kernel_convert_block_mxfp4_trans4_ns(
|
||||
global struct block_mxfp4 * src0,
|
||||
__global uint * dst_q,
|
||||
__global uchar * dst_e,
|
||||
uint ne00,
|
||||
uint ne01
|
||||
) {
|
||||
uint i00 = get_global_id(1);
|
||||
uint i01 = get_global_id(0);
|
||||
uint i02 = get_global_id(2);
|
||||
|
||||
uint ne00_blk = ne00 / QK_MXFP4;
|
||||
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
|
||||
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
|
||||
|
||||
global struct block_mxfp4 * b = src0 + src_blk_offset;
|
||||
dst_e[dst_blk_offset] = b->e;
|
||||
|
||||
// extract quantization and unshuffle
|
||||
ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0];
|
||||
|
||||
ushort8 post_block = (ushort8)(0);
|
||||
|
||||
uchar * pre_block_ptr = (uchar *)(&pre_block);
|
||||
uchar * post_block_ptr = (uchar *)(&post_block);
|
||||
|
||||
for (int i = 0; i < QK_MXFP4 / 4; ++i) {
|
||||
uchar x0 = pre_block_ptr[2*i + 0];
|
||||
uchar x1 = pre_block_ptr[2*i + 1];
|
||||
|
||||
post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
|
||||
post_block_ptr[i + QK_MXFP4 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
|
||||
}
|
||||
|
||||
uint4 q_block = as_uint4(post_block);
|
||||
|
||||
uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01;
|
||||
dst_q[offset] = q_block.x;
|
||||
dst_q[offset + ne01] = q_block.y;
|
||||
dst_q[offset + ne01 * 2] = q_block.z;
|
||||
dst_q[offset + ne01 * 3] = q_block.w;
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_mxfp4_trans4_ns(
|
||||
__global uint * src_q,
|
||||
__global uchar * src_e,
|
||||
__global struct block_mxfp4 * dst0,
|
||||
uint ne00,
|
||||
uint ne01
|
||||
) {
|
||||
uint i00 = get_global_id(1);
|
||||
uint i01 = get_global_id(0);
|
||||
uint i02 = get_global_id(2);
|
||||
|
||||
uint ne00_blk = ne00 / QK_MXFP4;
|
||||
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
|
||||
uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
|
||||
|
||||
__global struct block_mxfp4 * b = dst0 + dst_blk_offset;
|
||||
b->e = src_e[src_d_offset];
|
||||
|
||||
// collect transposed quantization parts for a block
|
||||
uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01;
|
||||
uint4 q_block;
|
||||
q_block.x = src_q[src_q_offset];
|
||||
q_block.y = src_q[src_q_offset + ne01];
|
||||
q_block.z = src_q[src_q_offset + ne01 * 2];
|
||||
q_block.w = src_q[src_q_offset + ne01 * 3];
|
||||
|
||||
ushort8 post_block = as_ushort8(q_block);
|
||||
ushort8 pre_block = (ushort8)(0);
|
||||
|
||||
uchar * pre_block_ptr = (uchar *)(&pre_block);
|
||||
uchar * post_block_ptr = (uchar *)(&post_block);
|
||||
|
||||
for (int i = 0; i < QK_MXFP4 / 4; ++i) {
|
||||
uchar x0 = post_block_ptr[i + 0];
|
||||
uchar x1 = post_block_ptr[i + QK_MXFP4 / 4];
|
||||
|
||||
pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
|
||||
pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
|
||||
}
|
||||
|
||||
((__global ushort8 *)(&(b->qs[0])))[0] = pre_block;
|
||||
}
|
||||
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q8_0
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
302
ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl
Normal file
302
ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl
Normal file
@@ -0,0 +1,302 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
|
||||
|
||||
#define TILESIZE_K 16
|
||||
#define TILESIZE_M 64
|
||||
#define TILESIZE_N 32
|
||||
|
||||
|
||||
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
|
||||
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
|
||||
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
|
||||
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
|
||||
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
|
||||
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
|
||||
|
||||
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
|
||||
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
|
||||
|
||||
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
|
||||
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
|
||||
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
|
||||
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
|
||||
|
||||
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
|
||||
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
|
||||
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
|
||||
sign_b.hi = fp4x8.s0 & 0x8000;
|
||||
|
||||
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
|
||||
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
|
||||
|
||||
ushort2 fp16_packed_a_1, fp16_packed_b_1;
|
||||
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
|
||||
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
|
||||
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
|
||||
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
|
||||
|
||||
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
|
||||
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
|
||||
|
||||
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
|
||||
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
|
||||
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
|
||||
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
|
||||
|
||||
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
|
||||
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
|
||||
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
|
||||
sign_b.hi = fp4x8.s1 & 0x8000;
|
||||
|
||||
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
|
||||
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
|
||||
|
||||
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
|
||||
}
|
||||
|
||||
|
||||
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
|
||||
acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \
|
||||
acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \
|
||||
acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \
|
||||
acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \
|
||||
acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \
|
||||
acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \
|
||||
acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \
|
||||
acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \
|
||||
acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \
|
||||
acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \
|
||||
acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \
|
||||
acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \
|
||||
acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \
|
||||
acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \
|
||||
acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \
|
||||
acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \
|
||||
acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \
|
||||
acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \
|
||||
acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \
|
||||
acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \
|
||||
acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \
|
||||
acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \
|
||||
acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \
|
||||
acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \
|
||||
acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \
|
||||
acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \
|
||||
acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \
|
||||
acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \
|
||||
acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \
|
||||
acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \
|
||||
acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \
|
||||
acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \
|
||||
c_reg.lo += convert_float8(acc.lo); \
|
||||
c_reg.hi += convert_float8(acc.hi); \
|
||||
acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \
|
||||
acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \
|
||||
acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \
|
||||
acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \
|
||||
acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \
|
||||
acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \
|
||||
acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \
|
||||
acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \
|
||||
acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \
|
||||
acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \
|
||||
acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \
|
||||
acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \
|
||||
acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \
|
||||
acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \
|
||||
acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \
|
||||
acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \
|
||||
acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \
|
||||
acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \
|
||||
acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \
|
||||
acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \
|
||||
acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \
|
||||
acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \
|
||||
acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \
|
||||
acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \
|
||||
acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \
|
||||
acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \
|
||||
acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \
|
||||
acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \
|
||||
acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \
|
||||
acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \
|
||||
acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \
|
||||
acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \
|
||||
c_reg.lo += convert_float8(acc.lo); \
|
||||
c_reg.hi += convert_float8(acc.hi); \
|
||||
|
||||
|
||||
static inline half e8m0_to_fp16(uchar x) {
|
||||
ushort bits;
|
||||
bits = (ushort)(x) - (ushort)(112);
|
||||
bits = ((bits & 0x00E0) != 0) ? 0x7C00 : (bits << 10);
|
||||
return as_half(bits);
|
||||
}
|
||||
|
||||
static inline float e8m0_to_fp32(uchar x) {
|
||||
int bits;
|
||||
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
|
||||
return as_float(bits);
|
||||
}
|
||||
|
||||
|
||||
__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair
|
||||
kernel void kernel_gemm_moe_mxfp4_f32_ns(
|
||||
__read_only image1d_buffer_t src0_q,
|
||||
__global uchar * src0_d,
|
||||
__read_only image1d_buffer_t src1,
|
||||
__global uint * src2,
|
||||
__global ushort * src2_emap,
|
||||
__write_only image1d_buffer_t dst,
|
||||
__global int * total_tiles,
|
||||
uint ne00,
|
||||
uint ne01
|
||||
) {
|
||||
uint block_id_m = get_global_id(1); // m_tile
|
||||
uint block_id_n = get_global_id(2); // n_tile
|
||||
|
||||
// Boundary check
|
||||
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
|
||||
return;
|
||||
}
|
||||
|
||||
__private half16 reg_a;
|
||||
__private float32 reg_c = (float32)(0);
|
||||
__local half4 shared_b[128];
|
||||
|
||||
const ushort expert_id = src2_emap[block_id_n];
|
||||
|
||||
const uint row = block_id_m * TILESIZE_M;
|
||||
const uint col = block_id_n * TILESIZE_N;
|
||||
|
||||
uint sub_block_id_m = get_local_id(0);
|
||||
uint2 b_global_offset;
|
||||
b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
|
||||
b_global_offset.y = b_global_offset.x + (16 * ne00);
|
||||
uint2 b_local_offset;
|
||||
b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
|
||||
b_local_offset.y = b_local_offset.x + 16;
|
||||
|
||||
// Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks
|
||||
for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
|
||||
// First sub-block
|
||||
uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
|
||||
uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5);
|
||||
uint b_sub_offset = col * ne00 + step;
|
||||
|
||||
// Load scale for current mxfp4 block
|
||||
uint s_offset = s_sub_offset + get_global_id(0);
|
||||
float s = e8m0_to_fp32(src0_d[s_offset]);
|
||||
|
||||
// Load 16 fp4 (64-bits) in transposed layout
|
||||
uint2 mxfp4x16;
|
||||
mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
|
||||
mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
|
||||
|
||||
// Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
|
||||
float8 bx8_f32;
|
||||
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
|
||||
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
|
||||
// Convert to half and store to LM to share within the subgroup
|
||||
half8 bx8_f16 = convert_half8(bx8_f32);
|
||||
shared_b[b_local_offset.x] = bx8_f16.lo;
|
||||
shared_b[b_local_offset.y] = bx8_f16.hi;
|
||||
|
||||
// Dequantization
|
||||
reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
|
||||
reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;
|
||||
|
||||
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// 32 16x16 fp16 dot product with 8 elements reduction for better precision
|
||||
half16 acc;
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
|
||||
|
||||
// Repeat for second sub-block
|
||||
uint half_step = step + TILESIZE_K;
|
||||
q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
|
||||
b_sub_offset = col * ne00 + half_step;
|
||||
|
||||
// Load next 16 fp4 (64-bits) in transposed layout
|
||||
mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
|
||||
mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
|
||||
|
||||
// Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
|
||||
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
|
||||
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
|
||||
// Convert to half and store to LM to share within the subgroup
|
||||
bx8_f16 = convert_half8(bx8_f32);
|
||||
shared_b[b_local_offset.x] = bx8_f16.lo;
|
||||
shared_b[b_local_offset.y] = bx8_f16.hi;
|
||||
|
||||
// Dequantization
|
||||
reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
|
||||
reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;
|
||||
|
||||
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// 32 16x16 fp16 dot product with 3-levels reduction for better precision
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
|
||||
}
|
||||
|
||||
// Load poster router and share in LM
|
||||
__local uint out_idx[TILESIZE_N];
|
||||
|
||||
if (get_local_id(0) < TILESIZE_N) {
|
||||
uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
|
||||
if (idx == 0xFFFFFFFF) {
|
||||
idx = src2[block_id_n * TILESIZE_N + 0];
|
||||
}
|
||||
out_idx[get_local_id(0)] = idx * ne01;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Scatter results back to original position in output grid
|
||||
uint m_offset = row + get_local_id(0);
|
||||
|
||||
write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1));
|
||||
write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2));
|
||||
write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3));
|
||||
write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4));
|
||||
write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5));
|
||||
write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6));
|
||||
write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7));
|
||||
write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8));
|
||||
write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9));
|
||||
write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
|
||||
write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
|
||||
write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
|
||||
write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
|
||||
write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
|
||||
write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
|
||||
write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
|
||||
write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
|
||||
write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
|
||||
write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
|
||||
write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
|
||||
write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
|
||||
write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
|
||||
write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
|
||||
write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
|
||||
write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
|
||||
write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
|
||||
write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
|
||||
write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
|
||||
write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
|
||||
write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
|
||||
write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));
|
||||
|
||||
// Store zero padding parts to the index of first output in tile, override correct result in the end
|
||||
barrier(CLK_GLOBAL_MEM_FENCE);
|
||||
write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
|
||||
}
|
||||
@@ -17,7 +17,7 @@
|
||||
REQD_SUBGROUP_SIZE_128
|
||||
#endif
|
||||
|
||||
kernel void kernel_mul_mat_Ab_Bi_8x4(
|
||||
kernel void kernel_gemm_noshuffle_q4_0_f32(
|
||||
global const ushort * src0_q, // quantized A
|
||||
global const half * src0_d, // A scales
|
||||
__read_only image1d_buffer_t src1, // B (1d image)
|
||||
@@ -11,7 +11,7 @@
|
||||
REQD_SUBGROUP_SIZE_128
|
||||
#endif
|
||||
|
||||
kernel void kernel_mul_mm_q8_0_f32_8x4(
|
||||
kernel void kernel_gemm_noshuffle_q8_0_f32(
|
||||
global const uint * src0_q,
|
||||
global const half * src0_d,
|
||||
__read_only image1d_buffer_t src1,
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user