Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2026-05-14 17:07:43 +03:00)
Compare commits
110 Commits
Commit SHA1s:

e20b83930c, fd89556567, 60489932ec, 4a4f819cb6, 046e284437, 66001722aa, c5703e03a5, b46812de78, 49956041ee, 9f5f0e689c,
f9cd456ea5, 5d6f18a638, 29debb3a6a, 9dcf835528, 58e68df0f9, 9b2925e1e0, a8fd165fec, 6d57a49a70, 3e941b813b, f3e8d149ce,
1d72d87349, 6a2a2513dc, 44dbe8c521, 05ff59cb57, aaf4a4d5e0, e43431b381, ceb7e14b96, 093be624cc, deab41ec68, ad09224658,
b9afc19cb4, 803627f121, 68380ae11b, cc97e45a14, 8e52631d55, f4b5a2ee91, 97f06e9eed, e358d75adb, cfff1fc300, 3980e04d5a,
2496f9c149, 5207d120ea, a0101225bc, a290ce6266, a00e47e422, 750141969c, a736e6c0ac, e3e3f8e46a, f08f20a0e3, 07eaf919ed,
74d6248f71, 2ca1161bd7, bbeb89d76c, ff806a110d, d5003b6e4d, 2635ac76e8, 70a8309114, c91faf997f, bf76ac77be, a09a00e502,
2bacb1eb77, d6e7b033a4, fa595462ca, a817a22bc6, eff06702b2, e77056f9b2, 935a340292, d8794eecd5, 36a694c965, a4701c98f7,
994118a183, c84e6d6db5, fa8feaed34, 846262d787, 6dcd824fce, d4b0c22f9e, e48034dfc9, 048a490f76, db44417b02, d05fe1d7da,
0754b7b6fe, 09294365a9, 63d93d1733, c5a3bc39b1, 9dbb372610, 228e836344, ed23489f42, 457e2288c9, e8ec7ab058, 1a03cf47f6,
b97ebdc98f, 2098fd6169, ab6120cde5, c3c1505392, 05e141a6b3, aab68217b7, a95a11e5b8, 5cbfb18075, beb42fffa4, 660b1b4bdc,
c20c44514a, 6118c043b1, 5f0ab726f7, e82aaf2587, 27aef3dd91, 45155597aa, 80afa33aad, b42c7fa5b8, d77599234e, 41a63be28e
@@ -12,6 +12,8 @@ body:
       after recreating the CMake build directory and with `-DGGML_CCACHE=OFF`.
       If the compilation succeeds with ccache disabled you should be able to permanently fix the issue
       by clearing `~/.cache/ccache` (on Linux).
+
+      Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
   - type: textarea
     id: commit
     attributes:

.github/ISSUE_TEMPLATE/011-bug-results.yml (vendored, 4 changed lines)
@@ -1,5 +1,5 @@
 name: Bug (model use)
-description: Something goes wrong when using a model (in general, not specific to a single llama.cpp module).
+description: Something goes wrong when running a model (crashes, garbled outputs, etc.).
 title: "Eval bug: "
 labels: ["bug-unconfirmed", "model evaluation"]
 body:
@@ -12,6 +12,8 @@ body:
       If you encountered the issue while using an external UI (e.g. ollama),
       please reproduce your issue using one of the examples/binaries in this repository.
       The `llama-completion` binary can be used for simple and reproducible model inference.
+
+      Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
   - type: textarea
     id: version
     attributes:

.github/ISSUE_TEMPLATE/019-bug-misc.yml (vendored, 2 changed lines)
@@ -10,6 +10,8 @@ body:
       This issue template is intended for miscellaneous bugs that don't fit into any other category.
       If you encountered the issue while using an external UI (e.g. ollama),
       please reproduce your issue using one of the examples/binaries in this repository.
+
+      Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
   - type: textarea
     id: version
     attributes:

.github/ISSUE_TEMPLATE/020-enhancement.yml (vendored, 2 changed lines)
@@ -8,6 +8,8 @@ body:
     value: |
       [Please post your idea first in Discussion if there is not yet a consensus for this enhancement request. This will help to keep this issue tracker focused on enhancements that the community has agreed needs to be implemented.](https://github.com/ggml-org/llama.cpp/discussions/categories/ideas)
+
+      Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
 
   - type: checkboxes
     id: prerequisites
     attributes:

.github/ISSUE_TEMPLATE/030-research.yml (vendored, 2 changed lines)
@@ -8,6 +8,8 @@ body:
     value: |
       Don't forget to check for any [duplicate research issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3A%22research+%F0%9F%94%AC%22)
+
+      Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
 
   - type: checkboxes
     id: research-stage
     attributes:

.github/ISSUE_TEMPLATE/040-refactor.yml (vendored, 2 changed lines)
@@ -9,6 +9,8 @@ body:
       Don't forget to [check for existing refactor issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3Arefactoring) in case it's already covered.
       Also you may want to check [Pull request refactor label as well](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Aopen+is%3Apr+label%3Arefactoring) for duplicates too.
+
+      Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
 
   - type: textarea
     id: background-description
     attributes:

.github/workflows/gguf-publish.yml (vendored, 2 changed lines)
@@ -29,10 +29,10 @@ jobs:
         uses: actions/setup-python@v6
         with:
           python-version: '3.11'
+          pip-install: poetry==2.4.0
       - name: Install dependencies
         run: |
           cd gguf-py
-          python -m pip install poetry==2.3.2
           poetry install
 
       - name: Build package

.github/workflows/python-type-check.yml (vendored, 2 changed lines)
@@ -31,7 +31,7 @@ jobs:
         uses: actions/setup-python@v6
         with:
           python-version: "3.11"
-          pip-install: -r requirements/requirements-all.txt ty==0.0.26
+          pip-install: -r requirements/requirements-all.txt ty==0.0.33
       # - name: Type-check with Pyright
       #   uses: jakebailey/pyright-action@v2
       #   with:

.gitignore (vendored, 2 changed lines)
@@ -105,6 +105,8 @@
 __pycache__/
 */poetry.lock
 poetry.toml
+poetry.lock
+uv.lock
 
 # Nix
@@ -4,6 +4,7 @@ General:
 - Be very precise and concise when writing code, comments, explanations, etc.
 - PR and commit titles format: `<module> : <title>`. Lookup recents for examples
 - Don't try to build or run the code unless you are explicitly asked to do so
+- Use the `gh` CLI tool when querying PRs, issues, or other GitHub resources
 
 Coding:
 - When in doubt, always refer to the CONTRIBUTING.md file of the project
@@ -76,6 +76,7 @@
 /ggml/src/ggml-vulkan/ @ggml-org/ggml-vulkan
 /ggml/src/ggml-webgpu/ @ggml-org/ggml-webgpu
 /ggml/src/ggml-zdnn/ @ggml-org/ggml-zdnn @Andreas-Krebbel @AlekseiNikiforovIBM
+/ggml/src/ggml-zendnn/ @avinashcpandey @Jiten1parmar @z-vishal
 /ggml/src/ggml.c @ggerganov
 /ggml/src/ggml.cpp @ggerganov
 /ggml/src/gguf.cpp @JohannesGaessler @Green-Sky
@@ -529,6 +529,7 @@ To learn more about model quantization, [read this documentation](tools/quantize
 - [How to build](docs/build.md)
 - [Running on Docker](docs/docker.md)
 - [Build on Android](docs/android.md)
+- [Multi-GPU usage](docs/multi-gpu.md)
 - [Performance troubleshooting](docs/development/token_generation_performance_tips.md)
 - [GGML tips & tricks](https://github.com/ggml-org/llama.cpp/wiki/GGML-Tips-&-Tricks)
@@ -248,6 +248,8 @@ std::vector<std::string> common_arg::get_env() const {
 
 // Helper function to parse tensor buffer override strings
 static void parse_tensor_buffer_overrides(const std::string & value, std::vector<llama_model_tensor_buft_override> & overrides) {
+    ggml_backend_load_all();
+
     std::map<std::string, ggml_backend_buffer_type_t> buft_list;
     for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
         auto * dev = ggml_backend_dev_get(i);
@@ -425,6 +427,10 @@ static bool parse_bool_value(const std::string & value) {
     }
 }
 
+[[noreturn]] static void arg_removed(const std::string & msg) {
+    throw std::invalid_argument("the argument has been removed. " + msg);
+}
+
 //
 // CLI argument parsing functions
 //
@@ -803,6 +809,7 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
     if (dev_names.size() == 1 && dev_names[0] == "none") {
         devices.push_back(nullptr);
     } else {
+        ggml_backend_load_all();
         for (const auto & device : dev_names) {
             auto * dev = ggml_backend_dev_by_name(device.c_str());
             if (!dev || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
@@ -820,6 +827,7 @@ static void add_rpc_devices(const std::string & servers) {
     if (rpc_servers.empty()) {
         throw std::invalid_argument("no RPC servers specified");
     }
+    ggml_backend_load_all();
     ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
     if (!rpc_reg) {
         throw std::invalid_argument("failed to find RPC backend");
@@ -1016,9 +1024,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
 
     params.use_color = tty_can_use_colors();
 
-    // load dynamic backends
-    ggml_backend_load_all();
-
     common_params_context ctx_arg(params);
     ctx_arg.print_usage = print_usage;
     ctx_arg.ex = ex;
@@ -2275,6 +2280,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        {"--list-devices"},
        "print list of available devices and exit",
        [](common_params &) {
+           ggml_backend_load_all();
            std::vector<ggml_backend_dev_t> devices;
            for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
                auto * dev = ggml_backend_dev_get(i);
@@ -2864,7 +2870,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        {"--tools"}, "TOOL1,TOOL2,...",
        "experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)\n"
        "specify \"all\" to enable all tools\n"
-       "available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff",
+       "available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff, get_datetime",
        [](common_params & params, const std::string & value) {
            params.server_tools = parse_csv_row(value);
        }
@@ -3380,7 +3386,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
    ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
    add_opt(common_arg(
        {"--spec-draft-poll", "--poll-draft"}, "<0|1>",
-       "Use polling to wait for draft model work (default: same as --poll])",
+       "Use polling to wait for draft model work (default: same as --poll)",
        [](common_params & params, int value) {
            params.speculative.draft.cpuparams.poll = value;
        }
@@ -3499,7 +3505,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
    ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_N_MIN"));
 
    add_opt(common_arg(
-       {"--spec--draft-p-split", "--draft-p-split"}, "P",
+       {"--spec-draft-p-split", "--draft-p-split"}, "P",
        string_format("speculative decoding split probability (default: %.2f)", (double)params.speculative.draft.p_split),
        [](common_params & params, const std::string & value) {
            params.speculative.draft.p_split = std::stof(value);
@@ -3715,35 +3721,35 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        {"--draft", "--draft-n", "--draft-max"}, "N",
        "the argument has been removed. use --spec-draft-n-max or --spec-ngram-mod-n-max",
        [](common_params & /*params*/, int /*value*/) {
-           throw std::invalid_argument("the argument has been removed. use --spec-draft-n-max or --spec-ngram-mod-n-max");
+           arg_removed("use --spec-draft-n-max or --spec-ngram-mod-n-max");
        }
    ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_MAX"));
    add_opt(common_arg(
        {"--draft-min", "--draft-n-min"}, "N",
        "the argument has been removed. use --spec-draft-n-min or --spec-ngram-mod-n-min",
        [](common_params & /*params*/, int /*value*/) {
-           throw std::invalid_argument("the argument has been removed. use --spec-draft-n-min or --spec-ngram-mod-n-min");
+           arg_removed("use --spec-draft-n-min or --spec-ngram-mod-n-min");
        }
    ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_MIN"));
    add_opt(common_arg(
        {"--spec-ngram-size-n"}, "N",
        "the argument has been removed. use the respective --spec-ngram-*-size-n or --spec-ngram-mod-n-match",
        [](common_params & /*params*/, int /*value*/) {
-           throw std::invalid_argument("the argument has been removed. use the respective --spec-ngram-*-size-n");
+           arg_removed("use the respective --spec-ngram-*-size-n");
        }
    ).set_spec().set_examples({LLAMA_EXAMPLE_SERVER}));
    add_opt(common_arg(
        {"--spec-ngram-size-m"}, "N",
        "the argument has been removed. use the respective --spec-ngram-*-size-m",
        [](common_params & /*params*/, int /*value*/) {
-           throw std::invalid_argument("the argument has been removed. use the respective --spec-ngram-*-size-m");
+           arg_removed("use the respective --spec-ngram-*-size-m");
        }
    ).set_spec().set_examples({LLAMA_EXAMPLE_SERVER}));
    add_opt(common_arg(
        {"--spec-ngram-min-hits"}, "N",
        "the argument has been removed. use the respective --spec-ngram-*-min-hits",
        [](common_params & /*params*/, int /*value*/) {
-           throw std::invalid_argument("the argument has been removed. use the respective --spec-ngram-*-min-hits");
+           arg_removed("use the respective --spec-ngram-*-min-hits");
        }
    ).set_spec().set_examples({LLAMA_EXAMPLE_SERVER}));
 
@@ -3794,7 +3800,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
    add_opt(common_arg(
        {"--diffusion-algorithm"}, "N",
-       string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)", params.diffusion.algorithm),
+       string_format(
+           "diffusion algorithm: 0=DIFFUSION_ALGORITHM_ORIGIN, 1=DIFFUSION_ALGORITHM_ENTROPY_BASED, "
+           "2=DIFFUSION_ALGORITHM_MARGIN_BASED, 3=DIFFUSION_ALGORITHM_RANDOM, "
+           "4=DIFFUSION_ALGORITHM_CONFIDENCE_BASED (default: %d)", params.diffusion.algorithm),
        [](common_params & params, int value) { params.diffusion.algorithm = value; }
    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
    add_opt(common_arg(
@@ -136,10 +136,10 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) const
     if (!end.empty()) {
         if (!start.empty()) {
             // Standard tag-based: optional(<think>reasoning</think>)
-            return p.optional(start + p.reasoning(p.until(end)) + end + p.space());
+            return p.optional(p.optspace(start) + p.reasoning(p.until(trim_whitespace(end))) + p.optspace(end));
         }
         // Delimiter-style (empty start)
-        return p.optional(p.reasoning(p.until(end)) + end + p.space());
+        return p.optional(p.reasoning(p.until(trim_whitespace(end))) + p.optspace(end));
     }
 }
@@ -186,7 +186,6 @@ common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const
 common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_context & ctx) const {
     auto & p = ctx.p;
     const auto & inputs = ctx.inputs;
-    bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
 
     // Build effective field names with dot notation if function_field is set
     std::string name_field = format.name_field;
@@ -225,8 +224,7 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont
         tool_start = format.per_call_start;
     }
 
-    return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(p.until(tool_start)))) + tools_parser +
-           p.end();
+    return ctx.reasoning_parser + p.optional(p.content(p.until(tool_start))) + tools_parser + p.end();
 }
 
 common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p, const std::string & name,
@@ -270,7 +268,6 @@ common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p,
 common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context & ctx) const {
     auto & p = ctx.p;
     const auto & inputs = ctx.inputs;
-    bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
 
     common_peg_parser tool_choice = p.choice();
 
@@ -336,14 +333,12 @@ common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context
 
     std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
     auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker);
-    return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls +
-           p.end();
+    return ctx.reasoning_parser + p.optional(p.content(content_before_tools)) + tool_calls + p.end();
 }
 
 common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_context & ctx) const {
     auto & p = ctx.p;
     const auto & inputs = ctx.inputs;
-    bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
 
     auto until_suffix = p.rule("until-suffix", p.until(arguments.value_suffix));
 
@@ -374,9 +369,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
                 arguments.name_suffix) +
             arguments.value_prefix +
             (schema_info.resolves_to_string(param_schema) ?
-                 p.tool_arg_string_value(p.schema(until_suffix,
-                     "tool-" + name + "-arg-" + param_name + "-schema",
-                     param_schema, true)) :
+                 p.tool_arg_string_value(until_suffix) :
                  p.tool_arg_json_value(p.schema(
                      p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
             p.space()) +
@@ -471,8 +464,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
 
     std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
     auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker);
-    return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls +
-           p.end();
+    return ctx.reasoning_parser + p.optional(p.content(content_before_tools)) + tool_calls + p.end();
 }
 
 } // namespace autoparser
@@ -342,7 +342,7 @@ void analyze_reasoning::compare_thinking_enabled() {
     if (left_trimmed.empty() && !diff.right.empty()) {
         if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) {
             if (start.empty()) {
-                start = trim_leading_whitespace(diff.right);
+                start = diff.right;
                 mode = reasoning_mode::TAG_BASED;
             }
         }
@@ -353,7 +353,7 @@ void analyze_reasoning::compare_thinking_enabled() {
         if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) {
             start = seg[seg.size() - 2].value;
         }
-        end = trim_trailing_whitespace(diff.left);
+        end = diff.left;
         mode = reasoning_mode::TAG_BASED;
     }
 }
@@ -445,14 +445,14 @@ void analyze_reasoning::compare_reasoning_scope() {
     auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
     if (result.result.success()) {
         start = result.tags["pre"];
-        end = trim_trailing_whitespace(result.tags["post"]);
+        end = result.tags["post"];
     } else {
         auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
             return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())));
         });
         result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B);
         if (result.result.success()) {
-            end = trim_trailing_whitespace(result.tags["post"]);
+            end = result.tags["post"];
         } else {
             LOG_DBG(ANSI_ORANGE "%s: Unable to extract reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
             mode = reasoning_mode::NONE;
@@ -816,6 +816,32 @@ common_peg_parser common_chat_peg_builder::prefix(const std::string & s, const s
     return literal(s.substr(0, s.rfind(delimiter)));
 }
 
+common_peg_parser common_chat_peg_builder::optspace(const std::string & tag) {
+    auto parser = eps();
+    size_t end_of_prefix_space = tag.size();
+    size_t start_of_suffix_space = tag.size();
+    for (size_t i = 0; i < tag.size(); i++) {
+        if (!std::isspace(tag[i])) {
+            end_of_prefix_space = i;
+            break;
+        }
+    }
+    for (size_t i = tag.size(); i > 0; i--) {
+        if (!std::isspace(tag[i - 1])) {
+            start_of_suffix_space = i;
+            break;
+        }
+    }
+    for (size_t i = 0; i < end_of_prefix_space; i++) {
+        parser += optional(literal(std::string(1, tag[i])));
+    }
+    parser += literal(tag.substr(end_of_prefix_space, start_of_suffix_space - end_of_prefix_space));
+    for (size_t i = start_of_suffix_space; i < tag.size(); i++) {
+        parser += optional(literal(std::string(1, tag[i])));
+    }
+    return parser;
+}
+
 common_peg_parser common_chat_peg_builder::standard_json_tools(
     const std::string & section_start,
     const std::string & section_end,
@@ -96,6 +96,9 @@ class common_chat_peg_builder : public common_peg_parser_builder {
     // Return a parser that parses the prefix of a string, up to a given delimiter.
     common_peg_parser prefix(const std::string & s, const std::string & delimiter = {});
 
+    // Return a parser that parses all elements of tag, but leading and trailing spaces are optional
+    common_peg_parser optspace(const std::string & tag);
+
     // Legacy-compatible helper for building standard JSON tool calls
     // Used by tests and manual parsers
     // name_key/args_key: JSON key names for function name and arguments
@@ -80,7 +80,7 @@ json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
     if (!content.empty()) {
         jmsg["content"] = content;
     } else if (!content_parts.empty()) {
-        if (concat_typed_text) {
+        if (concat_typed_text || contains_media()) {
             std::string text;
             bool last_was_media_marker = false;
             // join parts with newline, do not add newline before or after media markers
@@ -2116,22 +2116,38 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
     return std::nullopt;
 }
 
+static std::string common_chat_templates_generation_prompt(const common_chat_template & tmpl, const autoparser::generation_params & inputs) {
+    autoparser::generation_params params = inputs;
+    params.add_generation_prompt = false;
+    std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
+    params.add_generation_prompt = true;
+    std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
+
+    size_t prefix_len = 0;
+    size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size());
+    while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) {
+        prefix_len++;
+    }
+    return gen_prompt.substr(prefix_len);
+}
+
 static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
                                                             const struct common_chat_templates_inputs & inputs) {
     autoparser::generation_params params;
     params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
     const auto & tmpl =
         params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default;
     const auto & src = tmpl.source();
     const auto & caps = tmpl.original_caps();
     params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
     params.tool_choice = inputs.tool_choice;
     params.reasoning_format = inputs.reasoning_format;
     params.enable_thinking = inputs.enable_thinking;
     params.grammar = inputs.grammar;
     params.now = inputs.now;
+    params.add_generation_prompt = inputs.add_generation_prompt;
     params.add_bos = tmpls->add_bos;
     params.add_eos = tmpls->add_eos;
 
     if (src.find("<|channel|>") == std::string::npos) {
         // map developer to system for all models except for GPT-OSS
@@ -2153,14 +2169,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
         workaround::func_args_not_string(params.messages);
     }
 
-    params.add_generation_prompt = false;
-    std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
-    params.add_generation_prompt = true;
-    std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
-    auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
-    params.generation_prompt = diff.right + diff.suffix;
-
-    params.add_generation_prompt = inputs.add_generation_prompt;
+    params.generation_prompt = common_chat_templates_generation_prompt(tmpl, params);
 
     params.extra_context = common_chat_extra_context();
     for (auto el : inputs.chat_template_kwargs) {
@@ -2212,8 +2221,8 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
     auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
     auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
     if (auto_params.supports_thinking) {
-        auto_params.thinking_start_tag = autoparser.reasoning.start;
-        auto_params.thinking_end_tag = autoparser.reasoning.end;
+        auto_params.thinking_start_tag = trim_whitespace(autoparser.reasoning.start);
+        auto_params.thinking_end_tag = trim_whitespace(autoparser.reasoning.end);
     }
     auto_params.generation_prompt = params.generation_prompt;
     common_peg_arena arena;
@@ -94,6 +94,15 @@ struct common_chat_msg {
                tool_name.empty() && tool_call_id.empty();
     }
 
+    bool contains_media() const {
+        for (const auto & part : content_parts) {
+            if (part.type == "media_marker") {
+                return true;
+            }
+        }
+        return false;
+    }
+
     void set_tool_call_ids(std::vector<std::string> & ids_cache,
                            const std::function<std::string()> & gen_tool_call_id) {
         for (auto i = 0u; i < tool_calls.size(); i++) {
@@ -109,16 +109,24 @@ static std::vector<llama_device_memory_data> common_get_device_memory_data(
         ret.back().total = total;
     }
     for (size_t i = 0; i < nd; i++) {
+        ggml_backend_dev_t dev = llama_model_get_device(model, i);
+
         size_t free;
         size_t total;
-        ggml_backend_dev_memory(llama_model_get_device(model, i), &free, &total);
+        ggml_backend_dev_memory(dev, &free, &total);
 
-        // devices can return 0 bytes for free and total memory if they do not
-        // have any to report. in this case, we will use the host memory as a fallback
-        // fixes: https://github.com/ggml-org/llama.cpp/issues/18577
+        // Some non-GPU accelerator backends, such as BLAS, report 0/0 and rely on
+        // the host-memory fallback. For GPU-like backends, keep 0/0 so --fit does
+        // not assign anything to a device with an unknown memory budget.
         if (free == 0 && total == 0) {
-            free = ret.back().free;
-            total = ret.back().total;
+            const enum ggml_backend_dev_type type = ggml_backend_dev_type(dev);
+            if (type == GGML_BACKEND_DEVICE_TYPE_GPU || type == GGML_BACKEND_DEVICE_TYPE_IGPU) {
+                LOG_WRN("%s: device %s did not report memory; --fit will not use it\n",
+                        __func__, ggml_backend_dev_name(dev));
+            } else {
+                free = ret.back().free;
+                total = ret.back().total;
+            }
         }
         ret[i].free = free;
         ret[i].total = total;
@@ -57,7 +57,7 @@ static fs::path get_cache_directory() {
 #ifndef _WIN32
     const struct passwd * pw = getpwuid(getuid());
 
-    if (pw->pw_dir && *pw->pw_dir) {
+    if (pw && pw->pw_dir && *pw->pw_dir) {
         return fs::path(pw->pw_dir) / ".cache" / "huggingface" / "hub";
     }
 #endif
@@ -232,34 +232,6 @@ static struct llama_sampler * common_reasoning_budget_init_state(
     );
 }
 
-struct llama_sampler * common_reasoning_budget_init(
-    const struct llama_vocab * vocab,
-    const std::vector<llama_token> & start_tokens,
-    const std::vector<llama_token> & end_tokens,
-    const std::vector<llama_token> & forced_tokens,
-    int32_t budget,
-    const std::vector<llama_token> & prefill_tokens) {
-    // Determine initial state from prefill: COUNTING if the prefill begins with
-    // the start sequence but does not also contain the end sequence after it.
-    common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE;
-    if (!prefill_tokens.empty() && !start_tokens.empty() &&
-        prefill_tokens.size() >= start_tokens.size() &&
-        std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) {
-        initial_state = REASONING_BUDGET_COUNTING;
-        // If the end sequence also follows the start in the prefill, reasoning
-        // was opened and immediately closed — stay IDLE.
-        if (!end_tokens.empty() &&
-            prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) {
-            auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size();
-            if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() &&
-                std::equal(end_tokens.begin(), end_tokens.end(), end_start)) {
-                initial_state = REASONING_BUDGET_IDLE;
-            }
-        }
-    }
-    return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
-}
-
 struct llama_sampler * common_reasoning_budget_init(
     const struct llama_vocab * vocab,
     const std::vector<llama_token> & start_tokens,
@@ -29,10 +29,7 @@ enum common_reasoning_budget_state {
 // end_tokens    - token sequence for natural deactivation
 // forced_tokens - token sequence forced when budget expires
 // budget        - max tokens allowed in the reasoning block
-// prefill_tokens - tokens already present in the prompt (generation prompt);
-//                  used to determine the initial state: COUNTING if they begin
-//                  with start_tokens (but don't also end with end_tokens),
-//                  IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
+// initial_state - initial state
 //
 struct llama_sampler * common_reasoning_budget_init(
     const struct llama_vocab * vocab,
@@ -40,16 +37,6 @@ struct llama_sampler * common_reasoning_budget_init(
     const std::vector<llama_token> & end_tokens,
     const std::vector<llama_token> & forced_tokens,
     int32_t budget,
-    const std::vector<llama_token> & prefill_tokens = {});
-
-// Variant that takes an explicit initial state (used by tests and clone).
-// COUNTING with budget <= 0 is promoted to FORCING.
-struct llama_sampler * common_reasoning_budget_init(
-    const struct llama_vocab * vocab,
-    const std::vector<llama_token> & start_tokens,
-    const std::vector<llama_token> & end_tokens,
-    const std::vector<llama_token> & forced_tokens,
-    int32_t budget,
-    common_reasoning_budget_state initial_state);
+    common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE);
 
 common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);
@@ -260,32 +260,35 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
         }
     }
 
-    // Compute prefill tokens from the generation prompt
-    std::vector<llama_token> prefill_tokens;
-    if (!params.generation_prompt.empty()) {
-        GGML_ASSERT(vocab != nullptr);
-        auto tokens = common_tokenize(vocab, params.generation_prompt, false, true);
-        for (size_t i = 0; i < tokens.size(); i++) {
-            std::string piece = common_token_to_piece(vocab, tokens[i], true);
-            if (i == 0 && std::isspace(piece[0]) && !std::isspace(params.generation_prompt[0])) {
-                // Some tokenizers will add a space before the first special token, need to exclude
-                continue;
-            }
-            LOG_DBG("%s: prefill token: %d = %s\n", __func__, tokens[i], piece.c_str());
-            prefill_tokens.push_back(tokens[i]);
-        }
-    }
-
     // Feed generation prompt tokens to the grammar sampler so it advances past
     // tokens the template already placed in the prompt.
+    // Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled.
+    std::vector<llama_token> prefill_tokens;
+    if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) {
+        GGML_ASSERT(vocab != nullptr);
+        prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true);
+        if (!prefill_tokens.empty()) {
+            std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true);
+            if (std::isspace(first_token[0]) && !std::isspace(params.generation_prompt[0])) {
+                // Some tokenizers will add a space before the first special token, need to remove
+                prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end());
+            }
+        }
+    }
 
-    if (grmr && !params.grammar_lazy) {
+    if (grmr && !params.grammar_lazy && common_grammar_needs_prefill(params.grammar)) {
         try {
             for (const auto & token : prefill_tokens) {
                 llama_sampler_accept(grmr, token);
-                LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token);
+                LOG_DBG("%s: grammar accepted prefill token (%d)\n", __func__, token);
             }
         } catch (std::exception &e) {
             LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
                     common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
             throw e;
         }
    }
@@ -296,8 +299,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
            params.reasoning_budget_start,
            params.reasoning_budget_end,
            params.reasoning_budget_forced,
-           params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
-           prefill_tokens);
+           params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens);
+
+        for (const auto & token : prefill_tokens) {
+            llama_sampler_accept(rbudget, token);
+            LOG_DBG("%s: reasoning-budget accepted prefill token (%d)\n", __func__, token);
+        }
    }
 
    if (params.has_logit_bias()) {
@@ -431,7 +438,7 @@ static bool grammar_should_apply(struct common_sampler * gsmpl) {
     return true;
 }
 
-void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
+void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool is_generated) {
     if (!gsmpl) {
         return;
     }
@@ -439,9 +446,11 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
     const auto tm = gsmpl->tm();
 
     // grammar_should_apply() checks the reasoning budget state, so calculate this before we accept
-    accept_grammar = accept_grammar && grammar_should_apply(gsmpl);
+    const auto accept_grammar = is_generated && grammar_should_apply(gsmpl);
 
-    llama_sampler_accept(gsmpl->rbudget, token);
+    if (gsmpl->rbudget && is_generated) {
+        llama_sampler_accept(gsmpl->rbudget, token);
+    }
 
     if (gsmpl->grmr && accept_grammar) {
         llama_sampler_accept(gsmpl->grmr, token);
@@ -41,8 +41,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
 
 void common_sampler_free(struct common_sampler * gsmpl);
 
-// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
-void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
+// if is_generated is true, the token is accepted by the sampling chain, the reasoning budget sampler, and the grammar sampler
+void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool is_generated);
 void common_sampler_reset (struct common_sampler * gsmpl);
 struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
@@ -167,8 +167,6 @@ struct common_speculative_checkpoint {
     size_t size() const {
         return data.size();
     }
-
-    size_t ckpt_size = 0;
 };
 
 struct common_speculative_state_draft : public common_speculative_state {
@@ -176,7 +174,7 @@ struct common_speculative_state_draft : public common_speculative_state {
     llama_context * ctx_dft;
 
     bool use_ckpt = false;
-    struct common_speculative_checkpoint ckpt;
+    common_speculative_checkpoint ckpt;
 
     common_sampler * smpl;
 
@@ -249,29 +247,19 @@ struct common_speculative_state_draft : public common_speculative_state {
         llama_batch_free(batch);
     }
 
-    void begin(const llama_tokens & prompt) override {
-        if (use_ckpt && ckpt.size() > 0) {
-            // delete checkpoint
-            LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
-                    __func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
-            ckpt.pos_min = 0;
-            ckpt.pos_max = 0;
-            ckpt.n_tokens = 0;
-            ckpt.ckpt_size = 0;
-            ckpt.data.clear();
-        }
+    void begin(const llama_tokens & /*prompt*/) override {
     }
 
-    size_t draft_create_checkpoint(int n_tokens_prompt, int n_tokens_batch) {
+    size_t create_checkpoint(int n_tokens_prompt) {
        int slot_id = 0;
-       const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+       const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
 
        ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
        ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
-       ckpt.n_tokens = n_tokens_prompt - n_tokens_batch;
+       ckpt.n_tokens = n_tokens_prompt;
        ckpt.data.resize(checkpoint_size);
 
-       const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+       const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
        if (n != checkpoint_size) {
            GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
        }
@@ -281,13 +269,13 @@ struct common_speculative_state_draft : public common_speculative_state {
        return n;
    }
 
-   size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) {
+   size_t restore_checkpoint() {
        int slot_id = 0;
        LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
-       const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-       if (n != ckpt_size_part_expected) {
-           GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
-                   __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
+       const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
+       if (n != ckpt.size()) {
+           GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu",
+                   __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());
        }
        llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);
 
@@ -346,13 +334,18 @@ struct common_speculative_state_draft : public common_speculative_state {
 
        const int i_start = std::max<int>(0, (int) prompt_cur.size() - n_ctx);
 
+       if (use_ckpt && i_start > 0) {
+           LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__);
+           return;
+       }
+
        // reuse as much as possible from the old draft context
        // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
        for (int i = 0; i < (int) prompt_dft.size(); ++i) {
            int cur = 0;
            while (i_start + cur < (int) prompt_cur.size() &&
                   i + cur < (int) prompt_dft.size() &&
                   prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
                cur++;
            }
@@ -360,21 +353,26 @@ struct common_speculative_state_draft : public common_speculative_state {
                reuse_i = i;
                reuse_n = cur;
            }
+
+           if (use_ckpt) {
+               break;
+           }
        }
 
        LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
                __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
-       if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) {
-           LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
-                   __func__, reuse_i, reuse_n);
+       if (use_ckpt && ckpt.n_tokens > reuse_n) {
+           LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens);
+
            reuse_i = 0;
            reuse_n = 0;
+
+           ckpt = {};
        }
 
        result.clear();
        result.reserve(sparams.n_max);
 
-       bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
        if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
            llama_memory_clear(mem_dft, false);
            prompt_dft.clear();
@@ -393,50 +391,38 @@ struct common_speculative_state_draft : public common_speculative_state {
            return;
        }
 
-       bool do_restore = false;
-       if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) {
-           // This can happen after a partial acceptance (speculative decoding with checkpoints)
-           LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n",
-                   __func__, prompt_dft.size(), prompt_cur.size());
-           prompt_dft.resize(prompt_cur.size());
-           do_restore = true;
-       }
-
        if (reuse_i > 0) {
+           GGML_ASSERT(!use_ckpt);
+
            bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
            if (!is_removed) {
                LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
                return;
            }
            llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
 
            prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
        }
 
-       if (reuse_n < (int) prompt_dft.size() || do_restore) {
+       if (reuse_n < (int) prompt_dft.size()) {
            if (use_ckpt) {
-               if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
-                   LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
-                           __func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
-               }
-               draft_restore_checkpoint(ckpt.ckpt_size);
-               reuse_n = ckpt.n_tokens;
-               prompt_dft.resize(reuse_n);
-               needs_ckpt = false;
+               if (ckpt.n_tokens > 0) {
+                   LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
+                   restore_checkpoint();
+                   reuse_n = ckpt.n_tokens;
+                   prompt_dft.resize(reuse_n);
+               }
            } else {
-               bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
+               const bool is_removed = llama_memory_seq_rm(mem_dft, 0, reuse_n, -1);
                if (!is_removed) {
-                   LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n",
-                           __func__, reuse_n, prompt_dft.size());
+                   LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
                    return;
                }
                prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
           }
       }
 
-      if (needs_ckpt) {
-          ckpt.ckpt_size = draft_create_checkpoint(prompt_dft.size(), batch.n_tokens);
-      }
-
       // prepare a batch to evaluate any new tokens in the prompt
       common_batch_clear(batch);
@@ -450,12 +436,17 @@ struct common_speculative_state_draft : public common_speculative_state {
       // we should rarely end-up here during normal decoding
       if (batch.n_tokens > 0) {
           //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
+          LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens);
 
           int ret = llama_decode(ctx_dft, batch);
           if (ret != 0 && ret != 1) {
              LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n",
                      __func__, ret, prompt_cur.size());
          }
+
+          if (use_ckpt) {
+              create_checkpoint(prompt_dft.size());
+          }
      }
 
      const llama_pos n_past = prompt_dft.size();
@@ -784,17 +775,15 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
     }
 
     void accept(uint16_t n_accepted) override {
         if (verbose) {
             LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last);
         }
 
         // compute acceptance fraction if we have a recorded draft length
         if (n_draft_last > 0) {
             const double f_acc = (double)n_accepted / (double)n_draft_last;
             if (f_acc < 0.5) {
                 n_low++;
                 if (n_low >= 3) {
-                    LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
+                    if (verbose) {
+                        LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
+                    }
 
                     mod.reset();
                     n_low = 0;
File diff suppressed because it is too large.
@@ -175,6 +175,7 @@ pre_computed_hashes = [
     {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
     {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
     {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
+    {"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openbmb/MiniCPM-V-4_6", "chkhsh": "1444df51289cfa8063b96f0e62b1125440111bc79a52003ea14b6eac7016fd5f"},
     {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
     # jina-v2-de variants
     {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
@@ -737,6 +737,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
 | UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.|
 
+## Compile-time Flags
+
+Pass these via `CXXFLAGS` or add a one-off `#define` to enable a flag on the spot.
+
+| Name | Function |
+|-----------------|----------------------------------------------------------------------------------|
+| DEBUG_SYCL_POOL | Enable device memory pool logging on teardown. Useful for profiling allocations. |
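For example, one way to enable `DEBUG_SYCL_POOL` for a single build without editing sources (a sketch: the `CXXFLAGS` route is the one named above, while `-DGGML_SYCL=ON` and the rest of the invocation are assumed to follow your usual SYCL build setup):

```bash
# Sketch: inject the compile-time flag via CXXFLAGS at configure time.
CXXFLAGS="-DDEBUG_SYCL_POOL" cmake -B build -DGGML_SYCL=ON
cmake --build build --config Release
```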
 
 ## Design Rule
 
 - Open to all contributors.
docs/multi-gpu.md (new file, 127 lines)
@@ -0,0 +1,127 @@
|
||||
# Using multiple GPUs with llama.cpp
|
||||
|
||||
This guide explains how to run [llama.cpp](https://github.com/ggml-org/llama.cpp) across more than one GPU. It covers the split modes, the command-line flags that control them, the limitations you need to know about, and ready-to-use recipes for `llama-cli` and `llama-server`.
|
||||
|
||||
The CLI arguments listed here are the same for both tools - or most llama.cpp binaries for that matter.
|
||||
|
||||
---
|
||||
|
||||
## When you need multi-GPU
|
||||
|
||||
Reach for multi-GPU when one of these is true:
|
||||
|
||||
- **The model doesn't fit in a single GPU's VRAM.** By spreading the weights across two or more GPUs the whole model can stay on accelerators. Otherwise part of the model will need to be run off of the comparatively slower system RAM.
|
||||
- **You want more throughput.** By distributing the computation across multiple GPUs, each individual GPU has to do less work. This can result in better prefill and/or token generation performance, depending on the split mode and interconnect speed vs. the speed of an individual GPU.
|
||||
|
||||
---
|
||||
|
||||
## The split modes
|
||||
|
||||
Set with `--split-mode` / `-sm`.
|
||||
|
||||
| Mode | What it does | When to use |
|
||||
|---|---|---|
|
||||
| `none` | Use a single GPU only. Pick which one with `--main-gpu`. | You explicitly want to confine the model to one GPU even though more are visible. |
|
||||
| `layer` (**default**) | Pipeline parallelism. Each GPU holds a contiguous slice of layers. The KV cache for layer *l* lives on the GPU that owns layer *l*. | Default and most compatible multi-GPU choice. You want more memory than a single GPU provides and your priority is a fast prefill. Can tolerate slow interconnect speeds between GPUs. |
|
||||
| `row` | **Deprecated.** Older row-split tensor-parallel path with comparatively poor performance. Splits only dense weights across GPUs. Superseded by `tensor` which should be universally superior if it can be used. | Avoid in new deployments. |
|
||||
| `tensor` | **EXPERIMENTAL.** Tensor parallelism that splits both weights *and* KV across the participating GPUs via a "meta device" abstraction. | You want more memory than a single GPU provides and your priority is fast token generation. Prefill speeds approach pipeline parallel speeds for large, dense models and fast GPU interconnect speeds. Treat as experimental as the code is less mature than pipeline parallelism. Performance should be good for multiple NVIDIA GPUs using the CUDA backend, no guarantees otherwise. |
|
||||
|
||||
> Pipeline parallel (`layer`) vs. tensor parallel (`tensor`): pipeline-parallel runs different layers on different GPUs and processes tokens sequentially through the pipeline. This minimizes data transfers between GPUs but requires many tokens to scale well. Tensor-parallel splits each layer across GPUs and does multiple cross-GPU reductions per layer. This enables parallelizing any workload but is much more bottlenecked by the GPU interconnect speed. Pipeline-parallel maximizes batch throughput; tensor-parallel minimizes latency.

---

## Command-line arguments reference

| Short | Long | Value | Default | Notes |
|---|---|---|---|---|
| `-sm` | `--split-mode` | `none` \| `layer` \| `tensor` | `layer` | See modes above. |
| `-ts` | `--tensor-split` | comma-separated proportions, e.g. `3,1` | mode-dependent | How much of the model goes to each GPU. If omitted, `layer`/`row` use automatic splitting proportional to memory, while `tensor` splits tensor segments evenly. With `3,1` on two GPUs, GPU 0 gets 75%, GPU 1 gets 25%. The values follow the order in `--device`. |
| `-mg` | `--main-gpu` | integer device index | `0` | The single GPU used in `--split-mode none`. |
| `-ngl` | `--n-gpu-layers` / `--gpu-layers` | integer \| `auto` \| `all` | `auto` | Maximum number of layers to keep in VRAM. Use `999` or `all` to push everything possible to the GPUs. |
| `-dev` | `--device` | comma-separated device names, or `none` | auto | Restrict which devices llama.cpp may use. See `--list-devices` for names. |
| | `--list-devices` | - | - | Print the available devices and their memory. Run this first to learn the names you'd pass to `--device`. |
| `-fa` | `--flash-attn` | `on` \| `off` \| `auto` | `auto` | Required when using `--split-mode tensor` and/or a quantized V cache. Supported (and therefore enabled by default) for most combinations of models and backends. |
| `-ctk` | `--cache-type-k` | `f32` \| `f16` \| `bf16` \| `q8_0` \| `q4_0` \| ... | `f16` | KV cache type for K. |
| `-ctv` | `--cache-type-v` | same as `-ctk` | `f16` | KV cache type for V. |
| `-fit` | `--fit` | `on` \| `off` | `on` | Auto-fit unset args to device memory. **Not supported with `tensor`; you may need to set `--ctx-size` manually to make the model fit.** |

As for any CUDA program, the environment variable `CUDA_VISIBLE_DEVICES` can be used to control which GPUs the CUDA backend uses: if you set it, llama.cpp only sees the specified GPUs. Use `--device` to select GPUs from among those visible to llama.cpp; this works for any backend.
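
For example (device names are whatever `--list-devices` prints on your system):

```bash
# CUDA backend only: expose just the first and third physical GPU to llama.cpp
CUDA_VISIBLE_DEVICES=0,2 llama-server -m model.gguf

# backend-agnostic: pick among the devices llama.cpp can already see
llama-server -m model.gguf -dev CUDA0,CUDA1
```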

---

## Recipes

### 1. Default - pipeline parallel across all visible GPUs

```bash
llama-cli -m model.gguf
llama-server -m model.gguf
```

Easiest configuration. The KV cache spreads across the GPUs along with the layers. `--fit` (on by default) sizes everything automatically.

### 2. Pipeline parallel with a custom split ratio

```bash
llama-cli -m model.gguf -ts 3,1
```

Useful when GPUs have different memory: GPU 0 (3 parts) and GPU 1 (1 part). Proportions are normalized, so `-ts 3,1` is the same as e.g. `-ts 75,25`.
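
Because the proportions are normalized, you can simply pass the cards' VRAM sizes; for a hypothetical 24 GB + 8 GB pair:

```bash
# equivalent to -ts 3,1
llama-cli -m model.gguf -ts 24,8
```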

### 3. Single-GPU mode, picking a specific GPU

```bash
llama-cli --list-devices
llama-cli -m model.gguf -dev CUDA1
```

Uses only the device listed as `CUDA1` in the output of `--list-devices`.

### 4. Tensor parallelism (experimental)

```bash
llama-cli -m model.gguf -sm tensor -ctk f16 -ctv f16
```

- `--flash-attn off` (or `--flash-attn auto` resolving to `off` when flash attention isn't supported) is a hard error.
- KV cache types must be non-quantized: `f32`, `f16`, or `bf16`. Support for a quantized KV cache is not implemented and trying to use one will result in an error.
- Mark this configuration as experimental in your tooling: validate output quality before deploying.
- `--split-mode tensor` is not implemented for all architectures. The following will fail with *"LLAMA_SPLIT_MODE_TENSOR not implemented for architecture '...'"*:
  - **MoE / hybrid:** Grok, MPT, OLMoE, DeepSeek2, GLM-DSA, Nemotron-H, Nemotron-H-MoE, Granite-Hybrid, LFM2-MoE, Minimax-M2, Mistral4, Kimi-Linear, Jamba, Falcon-H1
  - **State-space / RWKV-style:** Mamba, Mamba2 (and the hybrid Mamba-attention models above)
  - **Other:** PLAMO2, MiniCPM3, Gemma-3n, OLMo2, BitNet, T5
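
Putting these constraints together, a full invocation might look like the sketch below; the context size is set manually because `--fit` is unsupported in this mode (values are illustrative):

```bash
llama-server -m model.gguf -sm tensor -fa on -ctk f16 -ctv f16 -c 16384
```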

### 5. With NCCL

There's no runtime flag for NCCL - it's selected at build time (`-DGGML_CUDA_NCCL=ON`, the default). Note that NCCL is **not** automatically distributed with CUDA and you may need to install it manually - when in doubt, check the CMake log to see whether or not it can find the package. When llama.cpp is compiled with NCCL support it uses it automatically for cross-GPU reductions in `tensor` mode. When NCCL is missing on a multi-GPU build, you'll see this one-time warning and performance will be lower:

```
NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal
```

When using the "ROCm" backend (which is the ggml CUDA code translated for AMD via HIP), the AMD equivalent RCCL can be used by compiling with `-DGGML_HIP_RCCL=ON`. Note that RCCL is *disabled* by default because (unlike NCCL) it was not universally beneficial during testing.
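
Example configure commands (the backend-selection flags are assumed to match your existing build setup):

```bash
# NVIDIA / CUDA: NCCL is on by default, shown here only for explicitness
cmake -B build -DGGML_CUDA=ON -DGGML_CUDA_NCCL=ON

# AMD / ROCm: RCCL must be opted into
cmake -B build -DGGML_HIP=ON -DGGML_HIP_RCCL=ON
```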

### 6. With CUDA peer-to-peer access (`GGML_CUDA_P2P`)

CUDA peer-to-peer (P2P) lets GPUs transfer data directly between each other instead of going through system memory, which generally improves multi-GPU performance. It is **opt-in** at runtime - set the environment variable `GGML_CUDA_P2P` to any value to enable it:

```bash
GGML_CUDA_P2P=1 llama-cli -m model.gguf -sm tensor
```

P2P requires driver support (usually restricted to workstation/datacenter GPUs) and **may cause crashes or corrupted outputs on some motherboards or BIOS configurations** (e.g. when IOMMU is enabled). If you see instability after enabling it, unset the variable.

---

## Troubleshooting

| Symptom | How to fix |
|---|---|
| Startup error *"SPLIT_MODE_TENSOR requires flash_attn to be enabled"* | Add `-fa on` or remove `-fa off`. |
| Startup error *"simultaneous use of SPLIT_MODE_TENSOR and KV cache quantization not implemented"* | Use `-ctk f16 -ctv f16` (or `bf16`/`f32`) with `--split-mode tensor`. |
| Startup error *"LLAMA_SPLIT_MODE_TENSOR not implemented for architecture 'X'"* | Architecture not on the `tensor` allow-list. Use `--split-mode layer`. |
| Warning *"NCCL is unavailable, multi GPU performance will be suboptimal"* | llama.cpp wasn't built with NCCL. Either accept the lower performance or install NCCL and rebuild. |
| CUDA OOM at startup or during prefill in `--split-mode tensor` | Auto-fit is disabled in this mode, so reduce memory pressure yourself. In order from least to most disruptive: lower `--ctx-size` (`-c`) (KV cache size is roughly proportional to `n_ctx`); for `llama-server`, lower `--parallel` (`-np`) (a slot KV cache is allocated per concurrent sequence); as a last resort, reduce `--n-gpu-layers` (`-ngl`) (the remaining layers run on CPU and inference will be much slower). |
| Performance is worse with multi-GPU than single-GPU | The performance is bottlenecked by GPU interconnect speed. For `--split-mode tensor`, verify that NCCL is being used. Try `--split-mode layer` (less communication than `tensor`). Increase GPU interconnect speed via more PCIe lanes or e.g. NVLink (if available). |
| GPU not used at all | `--n-gpu-layers` is `0` or too low - try explicitly setting `-ngl all`. Or you are accidentally hiding the GPUs via an environment variable like `CUDA_VISIBLE_DEVICES=-1`. Or your build doesn't include support for the relevant backend. |
| Crashes or corrupted outputs after setting `GGML_CUDA_P2P=1` | Some motherboards and BIOS settings (e.g. with IOMMU enabled) don't support CUDA peer-to-peer reliably. Unset `GGML_CUDA_P2P`. |
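
For the OOM row, a hypothetical mitigation ladder (values are illustrative):

```bash
llama-server -m model.gguf -sm tensor -fa on -c 8192        # 1. smaller context
llama-server -m model.gguf -sm tensor -fa on -c 8192 -np 1  # 2. fewer parallel slots
llama-server -m model.gguf -sm tensor -fa on -ngl 30        # 3. last resort: fewer GPU layers
```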

docs/multimodal/minicpmv4.6.md (new file, 49 lines)
@@ -0,0 +1,49 @@

## MiniCPM-V 4.6

### Prepare models and code

Download the [MiniCPM-V-4_6](https://huggingface.co/openbmb/MiniCPM-V-4_6) PyTorch model from Hugging Face into a `MiniCPM-V-4_6` folder.

The model must be the standard `transformers` v5.7.0+ checkpoint (no `trust_remote_code`); the architecture in `config.json` is `MiniCPMV4_6ForConditionalGeneration`, with a `qwen3_5_text` text model and a SigLIP-based vision tower plus a window-attention `vit_merger`.

### Build llama.cpp

If your setup differs, refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md).

Clone llama.cpp:

```bash
git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp
```

Build llama.cpp using `CMake`:

```bash
cmake -B build
cmake --build build --config Release
```

### Usage of MiniCPM-V 4.6

Unlike older MiniCPM-V variants, MiniCPM-V 4.6 is converted directly through `convert_hf_to_gguf.py`. The same script is invoked twice on the original Hugging Face directory: once to produce the language-model GGUF and once with `--mmproj` to produce the multimodal projector GGUF.

```bash
# language model
python ./convert_hf_to_gguf.py ../MiniCPM-V-4_6 --outfile ../MiniCPM-V-4_6/ggml-model-f16.gguf

# multimodal projector (vision tower + window-attention vit_merger + DownsampleMLP merger)
python ./convert_hf_to_gguf.py ../MiniCPM-V-4_6 --mmproj --outfile ../MiniCPM-V-4_6/mmproj-model-f16.gguf

# optional: quantize to Q4_K_M
./build/bin/llama-quantize ../MiniCPM-V-4_6/ggml-model-f16.gguf ../MiniCPM-V-4_6/ggml-model-Q4_K_M.gguf Q4_K_M
```

Inference on Linux or Mac:

```bash
# run in single-turn mode
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_6/ggml-model-f16.gguf --mmproj ../MiniCPM-V-4_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"

# run in conversation mode
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_6/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-4_6/mmproj-model-f16.gguf
```
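
The model can presumably also be served over HTTP with `llama-server`, which accepts the same `--mmproj` flag (untested sketch):

```bash
./build/bin/llama-server -m ../MiniCPM-V-4_6/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-4_6/mmproj-model-f16.gguf
```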

docs/ops.md (12 lines changed)
@@ -17,7 +17,7 @@ Legend:
| ABS | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
-| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
+| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
@@ -36,15 +36,15 @@ Legend:
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
-| DIAG | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
+| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
+| DIAG | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
-| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
+| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
@@ -101,11 +101,11 @@ Legend:
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
-| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
+| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
-| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
+| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |

docs/ops/SYCL.csv (10301 lines changed)
File diff suppressed because it is too large

@@ -33,18 +33,18 @@ An example to use this approach can be the rewriting of source code by a LLM.
This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead.

```
-llama-server [...] --spec-type ngram-simple --draft-max 64
+llama-server [...] --spec-type ngram-simple --spec-draft-n-max 64
```

#### n-gram Map Key (`ngram-map-k`)

-This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-min-hits`, default is 1) before generating drafts.
+This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-map-k-min-hits`, default is 1) before generating drafts.

The number of accepted tokens is stored for each used n-gram.

**Example:**
```
-llama-server [...] --spec-type ngram-map-k --draft-max 64
+llama-server [...] --spec-type ngram-map-k --spec-draft-n-max 64
```

#### n-gram Map Key-4-Values (`ngram-map-k4v`)

@@ -55,7 +55,7 @@ The number of accepted tokens is stored for each used n-gram.

**Example:** Server options to be used if there are a lot of longer repetitions.
```
-llama-server [...] --spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2 --draft-max 64
+llama-server [...] --spec-type ngram-map-k4v --spec-ngram-map-k4v-size-n 8 --spec-ngram-map-k4v-size-m 8 --spec-ngram-map-k4v-min-hits 2 --spec-draft-n-max 64
```

### n-gram Mod (`ngram-mod`)

@@ -80,9 +80,9 @@ Currently, a single hash pool is shared across all server slots, so different re
# notes:
# - small `n` are not recommended
# - MoEs require long drafts
-# - dense models: can reduce `--draft-min` and `--draft-max`
+# - dense models: can reduce `--spec-ngram-mod-n-min` and `--spec-ngram-mod-n-max`

-llama-server ... --spec-type ngram-mod --spec-ngram-size-n 24 --draft-min 48 --draft-max 64
+llama-server ... --spec-type ngram-mod --spec-ngram-mod-n-match 24 --spec-ngram-mod-n-min 48 --spec-ngram-mod-n-max 64
```

Applications:

@@ -105,21 +105,90 @@ Example Video:

If a draft model is combined with a draftless decoding the draftless decoding has higher precedence.

### General Speculative Parameters

```
--draft, --draft-n, --draft-max N       number of tokens to draft for speculative decoding (default: 16)
                                        (env: LLAMA_ARG_DRAFT_MAX)
--draft-min, --draft-n-min N            minimum number of draft tokens to use for speculative decoding
                                        (default: 0)
                                        (env: LLAMA_ARG_DRAFT_MIN)
[...]
--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
                                        type of speculative decoding to use when no draft model is provided
                                        (default: none)
--spec-ngram-size-n N                   ngram size N for ngram-simple/ngram-map speculative decoding, length
                                        of lookup n-gram (default: 12)
--spec-ngram-size-m N                   ngram size M for ngram-simple/ngram-map speculative decoding, length
                                        of draft m-gram (default: 48)
--spec-ngram-min-hits N                 minimum hits for ngram-map speculative decoding (default: 1)
                                        (env: LLAMA_ARG_SPEC_TYPE)
--spec-default                          use default speculative decoding
```

### Draft Model Parameters

```
--spec-draft-model, -md, --model-draft FNAME
                                        draft model for speculative decoding (default: unused)
                                        (env: LLAMA_ARG_SPEC_DRAFT_MODEL)
--spec-draft-hf, -hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]
                                        HuggingFace repository for the draft model
--spec-draft-n-max N
                                        number of tokens to draft for speculative decoding (default: 16)
                                        (env: LLAMA_ARG_SPEC_DRAFT_N_MAX)
--spec-draft-n-min N
                                        minimum number of draft tokens to use for speculative decoding (default: 0)
                                        (env: LLAMA_ARG_SPEC_DRAFT_N_MIN)
--spec-draft-p-split, --draft-p-split P
                                        speculative decoding split probability (default: 0.10)
                                        (env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT)
--spec-draft-p-min, --draft-p-min P
                                        minimum speculative decoding probability (greedy) (default: 0.75)
                                        (env: LLAMA_ARG_SPEC_DRAFT_P_MIN)
--spec-draft-ctx-size, -cd, --ctx-size-draft N
                                        size of the prompt context for the draft model (default: 0, 0 = loaded from model)
                                        (env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE)
--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N
                                        max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)
                                        (env: LLAMA_ARG_N_GPU_LAYERS_DRAFT)
--spec-draft-device, -devd, --device-draft <dev1,dev2,..>
                                        comma-separated list of devices to use for offloading the draft model
--spec-draft-replace, --spec-replace TARGET DRAFT
                                        translate the string in TARGET into DRAFT if the draft model and main model are not compatible
```

### n-gram Mod Parameters

```
--spec-ngram-mod-n-match N
                                        ngram-mod lookup length (default: 24)
--spec-ngram-mod-n-min N
                                        minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48)
--spec-ngram-mod-n-max N
                                        maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64)
```

### n-gram Simple Parameters

```
--spec-ngram-simple-size-n N
                                        ngram size N for ngram-simple speculative decoding, length of lookup n-gram (default: 12)
--spec-ngram-simple-size-m N
                                        ngram size M for ngram-simple speculative decoding, length of draft m-gram (default: 48)
--spec-ngram-simple-min-hits N
                                        minimum hits for ngram-simple speculative decoding (default: 1)
```

### n-gram Map Key Parameters

```
--spec-ngram-map-k-size-n N
                                        ngram size N for ngram-map-k speculative decoding, length of lookup n-gram (default: 12)
--spec-ngram-map-k-size-m N
                                        ngram size M for ngram-map-k speculative decoding, length of draft m-gram (default: 48)
--spec-ngram-map-k-min-hits N
                                        minimum hits for ngram-map-k speculative decoding (default: 1)
```

### n-gram Map Key-4-Values Parameters

```
--spec-ngram-map-k4v-size-n N
                                        ngram size N for ngram-map-k4v speculative decoding, length of lookup n-gram (default: 12)
--spec-ngram-map-k4v-size-m N
                                        ngram size M for ngram-map-k4v speculative decoding, length of draft m-gram (default: 48)
--spec-ngram-map-k4v-min-hits N
                                        minimum hits for ngram-map-k4v speculative decoding (default: 1)
```
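
A combined sketch using the per-implementation flags above (values are illustrative, not tuned):

```bash
llama-server [...] --spec-type ngram-map-k \
    --spec-ngram-map-k-size-n 12 --spec-ngram-map-k-size-m 48 \
    --spec-ngram-map-k-min-hits 2 --spec-draft-n-max 64
```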

### `--spec-type TYPE`

@@ -140,21 +209,40 @@ Specifies a type of speculative decoding without draft model.
./llama-server [...] --spec-type ngram-simple
```

-### `--spec-ngram-size-n N`
+### `--spec-ngram-*-size-n N`

Sets the size N of the lookup n-gram for n-gram map based speculative decoding.
The n-gram size N determines how many tokens in a row to look back when searching for matching patterns.

-### `--spec-ngram-size-m M`
+Each n-gram implementation has its own parameter:
+
+- `--spec-ngram-simple-size-n` for `ngram-simple`
+- `--spec-ngram-map-k-size-n` for `ngram-map-k`
+- `--spec-ngram-map-k4v-size-n` for `ngram-map-k4v`
+- `--spec-ngram-mod-n-match` for `ngram-mod`
+
+### `--spec-ngram-*-size-m M`

Sets the size M of the draft m-gram for n-gram map based speculative decoding.
The m-gram size determines how many tokens to draft when a match is found.
Larger values can provide more speedup but may reduce acceptance rate.

-### `--spec-ngram-min-hits H`
+Each n-gram implementation has its own parameter:
+
+- `--spec-ngram-simple-size-m` for `ngram-simple`
+- `--spec-ngram-map-k-size-m` for `ngram-map-k`
+- `--spec-ngram-map-k4v-size-m` for `ngram-map-k4v`
+
+### `--spec-ngram-*-min-hits H`

This option defines how often a key has to appear in the token history to be used as a draft (default is 1).

+Each n-gram implementation has its own parameter:
+
+- `--spec-ngram-simple-min-hits` for `ngram-simple`
+- `--spec-ngram-map-k-min-hits` for `ngram-map-k`
+- `--spec-ngram-map-k4v-min-hits` for `ngram-map-k4v`
+
## Statistics
Each speculative decoding implementation prints statistics.

@@ -180,4 +268,3 @@ statistics ngram_map_k: #calls(b,g,a) = 6 1690 26, #gen drafts = 26, #acc drafts
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
- `#acc tokens`: number of tokens accepted by the main model
- `dur(b,g,a)`: durations of begin (new prompt), generation and accumulation (process acceptance).

examples/diffusion/CMakeLists.txt
@@ -1,5 +1,10 @@
+set(TARGET llama-diffusion)
+add_library(${TARGET} STATIC diffusion.cpp diffusion.h)
+target_link_libraries(${TARGET} PUBLIC llama llama-common ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PUBLIC cxx_std_17)
+
set(TARGET llama-diffusion-cli)
add_executable(${TARGET} diffusion-cli.cpp)
install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE llama llama-common ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE llama-diffusion llama llama-common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

@@ -12,11 +12,11 @@ The diffusion CLI supports various parameters to control the generation process:
### Core Diffusion Parameters
- `--diffusion-steps`: Number of diffusion steps (default: 256)
- `--diffusion-algorithm`: Algorithm for token selection
-  - `0`: ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006.
-  - `1`: ENTROPY_BASED - Entropy-based selection
-  - `2`: MARGIN_BASED - Margin-based selection
-  - `3`: RANDOM - Random selection
-  - `4`: CONFIDENCE_BASED - Confidence-based selection (default)
+  - `0`: DIFFUSION_ALGORITHM_ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006.
+  - `1`: DIFFUSION_ALGORITHM_ENTROPY_BASED - Entropy-based selection
+  - `2`: DIFFUSION_ALGORITHM_MARGIN_BASED - Margin-based selection
+  - `3`: DIFFUSION_ALGORITHM_RANDOM - Random selection
+  - `4`: DIFFUSION_ALGORITHM_CONFIDENCE_BASED - Confidence-based selection (default)
- More documentation here https://github.com/DreamLM/Dream
- `--diffusion-visual`: Enable live visualization during generation
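
A hypothetical invocation combining the flags above (model path and prompt are placeholders):

```bash
./build/bin/llama-diffusion-cli -m dream-7b.gguf -p "Write a haiku about autumn." \
    --diffusion-steps 128 --diffusion-algorithm 4 --diffusion-visual
```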

@@ -1,127 +1,23 @@
#include "arg.h"
#include "chat.h"
#include "common.h"
#include "diffusion.h"
#include "llama.h"
#include "log.h"

#include <limits.h>

#include <algorithm>
#include <clocale>
#include <cmath>
#include <cstring>
#include <limits>
#include <random>
#include <string>
#include <vector>

enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 };

// Unified transfer scheduling methods
enum transfer_schedule {
    TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining
    BLOCK_BASED = 1,    // LLaDA-style: process in blocks with get_num_transfer_tokens
};

typedef bool (*diffusion_step_callback_t)(int32_t step,
                                          int32_t total_steps,
                                          const llama_token * tokens,
                                          int32_t n_tokens,
                                          void * user_data);

struct diffusion_params {
    int32_t steps = 0;
    float temperature = 0;
    llama_token mask_token_id = LLAMA_TOKEN_NULL;
    diffusion_step_callback_t step_callback = nullptr;
    void * step_callback_user_data = nullptr;
    int32_t seed = 0;
    bool visual_mode = false;
    bool shift_logits = false; // Shift logits by -1 after decode

    float top_p = 0.;
    int32_t top_k = 0.;

    diffusion_algorithm algorithm = CONFIDENCE_BASED;
    transfer_schedule schedule = TIMESTEP_BASED;

    float cfg_scale = 0.; // Config scale for classifier-free guidance
    float eps = 0.; // Timestep scheduling
    int32_t block_length = 0; // Block size (for block scheduling)
    float alg_temp = 0; // algorithm temperature (0.0 = deterministic)
    bool add_gumbel_noise = false; // Add gumbel noise to the logits if temp > 0.0

    int32_t max_length = 0; // Maximum sequence length
};

struct callback_data {
    diffusion_params * diff_params;
    const llama_vocab * vocab;
    int32_t n_input;
};

static float calculate_confidence(const llama_token_data_array & cur_p,
                                  diffusion_algorithm algorithm,
                                  std::mt19937 & rng) {
    switch (algorithm) {
        case CONFIDENCE_BASED:
            return cur_p.data[cur_p.selected].p; // Selected token probability

        case ENTROPY_BASED:
            {
                float entropy = 0.0f;
                const float epsilon = 1e-10f;
                for (size_t i = 0; i < cur_p.size; i++) {
                    float prob = cur_p.data[i].p;
                    entropy += prob * logf(prob + epsilon);
                }
                return -entropy; // Higher entropy = lower confidence
            }

        case MARGIN_BASED:
            return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;

        case RANDOM:
            {
                std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
                return uniform(rng); // Random confidence
            }

        case ORIGIN:
            return cur_p.data[cur_p.selected].p;

        default:
            return 0.0f;
    }
}

// Unified transfer count calculation function
static int32_t calculate_transfer_count(int32_t step,
                                        int32_t total_steps,
                                        int32_t remaining_masked,
                                        transfer_schedule schedule,
                                        float eps,
                                        const std::vector<int32_t> & num_transfer_tokens = {}) {
    switch (schedule) {
        case TIMESTEP_BASED:
            {
                float t = 1.0f - (float) step / total_steps * (1.0f - eps);
                float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
                float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
                return (int32_t) (remaining_masked * p_transfer);
            }

        case BLOCK_BASED:
            if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
                return num_transfer_tokens[step];
            }
            return remaining_masked / (total_steps - step); // Fallback

        default:
            return remaining_masked / (total_steps - step);
    }
}

static bool diffusion_step_callback(int32_t step,
                                    int32_t total_steps,
                                    const llama_token * tokens,
@@ -176,341 +72,6 @@ static bool diffusion_step_callback(int32_t step,
    return true;
}

static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) {
    if (temperature == 0.0f) {
        return;
    }

    std::uniform_real_distribution<double> uniform(0.0, 1.0);
    for (int32_t i = 0; i < n_vocab; i++) {
        double noise = uniform(rng);
        // Prevent log(0)
        noise = std::max(noise, 1e-20);
        double gumbel_noise = std::pow(-std::log(noise), temperature);
        logits[i] = std::exp(logits[i]) / gumbel_noise;
    }
}

static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
    std::vector<int32_t> num_transfer_tokens(steps);

    int32_t base = mask_count / steps;
    int32_t remainder = mask_count % steps;

    for (int32_t i = 0; i < steps; i++) {
        num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
    }

    return num_transfer_tokens;
}

static void diffusion_generate(llama_context * ctx,
                               const llama_token * input_tokens,
                               llama_token * output_tokens,
                               int32_t n_input,
                               const diffusion_params & params,
                               int32_t & n_generated) {
    n_generated = 0;
    if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
        return;
    }

    const llama_model * model = llama_get_model(ctx);

    // Initialize with input and pad with mask tokens
    std::copy(input_tokens, input_tokens + n_input, output_tokens);
    std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);

    std::mt19937 rng(params.seed);

    llama_set_causal_attn(ctx, false);

    int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));

    std::vector<llama_token_data> candidates(n_vocab);
    std::vector<llama_token_data> conf_candidates;
    conf_candidates.reserve(params.max_length);
    std::vector<int32_t> mask_positions;
    mask_positions.reserve(params.max_length);

    // Setup sampler chain
    struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
    if (params.top_k > 0) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
    }
    if (params.top_p < 1.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
    }
    if (params.temperature > 0.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
    }
    llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));

    struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);

    llama_batch batch = llama_batch_init(params.max_length, 0, 1);
    batch.n_tokens = params.max_length;

    // Pre-allocate buffers for CFG if needed
    int32_t logits_size = n_vocab * params.max_length;
    std::vector<float> cond_logits_buffer;
    std::vector<llama_token> un_x_buffer;
    if (params.cfg_scale > 0.0f) {
        cond_logits_buffer.resize(logits_size);
        un_x_buffer.resize(params.max_length);
    }

    // For block-based processing
    std::vector<int32_t> num_transfer_tokens;
    int32_t num_blocks = 1;
    int32_t steps_per_block = params.steps;

    if (params.schedule == BLOCK_BASED) {
        GGML_ASSERT(params.max_length % params.block_length == 0);
        num_blocks = params.max_length / params.block_length;
        GGML_ASSERT(params.steps % num_blocks == 0);
        steps_per_block = params.steps / num_blocks;
    }

    std::vector<float> confidence(params.max_length);

    int64_t total_sampling_time = 0;
    int64_t total_time = 0;
    int64_t time_start = ggml_time_us();

    for (int block_num = 0; block_num < num_blocks; block_num++) {
        int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
        int32_t block_end = (params.schedule == BLOCK_BASED) ?
            std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
            params.max_length;

        // Count masked tokens in current block for block-based processing
        if (params.schedule == BLOCK_BASED) {
            int32_t block_mask_count = 0;
            for (int i = block_start; i < block_end; i++) {
                if (output_tokens[i] == params.mask_token_id) {
                    block_mask_count++;
                }
            }
            num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
        }

        for (int32_t step = 0; step < steps_per_block; step++) {
            int32_t global_step = block_num * steps_per_block + step;

            if (params.step_callback) {
                if (!params.step_callback(
                        global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) {
                    break;
                }
            }

            // Setup batch
            for (int32_t i = 0; i < params.max_length; i++) {
                batch.token[i] = output_tokens[i];
                batch.pos[i] = i;
                batch.n_seq_id[i] = 1;
                batch.seq_id[i][0] = 0;
                batch.logits[i] = 1;
            }

            float * logits = nullptr;

            if (params.cfg_scale > 0.0f) {
                int ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("Failed to generate conditional");
                    break;
                }
                float * cond_logits_ptr = llama_get_logits(ctx);
                std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));

                // Unconditional generation (mask input)
                std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
                for (int32_t i = 0; i < n_input; i++) {
                    un_x_buffer[i] = params.mask_token_id;
                }

                for (int32_t i = 0; i < params.max_length; i++) {
                    batch.token[i] = un_x_buffer[i];
                }
                ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("Failed to generate unconditional");
                    break;
                }
                float * uncond_logits = llama_get_logits(ctx);

                // Apply CFG
                for (int32_t i = 0; i < logits_size; i++) {
                    cond_logits_buffer[i] =
                        uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
                }
                logits = cond_logits_buffer.data();
            } else {
                int ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
                    break;
                }
                logits = llama_get_logits(ctx);
            }

            if (!logits) {
                LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
                break;
            }

            auto get_logits_for_pos = [&](int32_t pos) -> const float * {
                if (params.shift_logits) {
                    return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
                }
                return logits + (pos) *n_vocab;
            };

            int64_t time_start_sampling = ggml_time_us();

            mask_positions.clear();
            for (int32_t i = 0; i < params.max_length; i++) {
                if (output_tokens[i] == params.mask_token_id) {
                    // For block-based, only consider current block
                    if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) {
                        mask_positions.push_back(i);
                    }
                }
            }

            if (mask_positions.empty()) {
                break;
            }

            if (params.add_gumbel_noise && params.temperature > 0.0f) {
                add_gumbel_noise(logits, n_vocab, params.temperature, rng);
            }

            if (params.algorithm == ORIGIN) {
                int32_t transfer_count = calculate_transfer_count(
                    step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
                float p_transfer = (float) transfer_count / mask_positions.size();

                for (int32_t pos : mask_positions) {
                    if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
                        const float * pos_logits = get_logits_for_pos(pos);
                        for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                            candidates[token_id].id = token_id;
                            candidates[token_id].logit = pos_logits[token_id];
                            candidates[token_id].p = 0.0f;
                        }

                        llama_token_data_array cur_p = {
                            candidates.data(),
                            (size_t) n_vocab,
                            -1,
                            false,
                        };

                        llama_sampler_apply(sampler, &cur_p);
                        output_tokens[pos] = cur_p.data[cur_p.selected].id;
                    }
                }
            } else {
                std::vector<std::pair<float, int32_t>> confidences;
                std::vector<llama_token> sampled_tokens(mask_positions.size());

                for (size_t i = 0; i < mask_positions.size(); i++) {
                    int32_t pos = mask_positions[i];
                    const float * pos_logits = get_logits_for_pos(pos);

                    for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                        candidates[token_id].logit = pos_logits[token_id];
                        candidates[token_id].p = 0.0f;
                        candidates[token_id].id = token_id;
                    }

                    llama_token_data_array cur_p = {
                        candidates.data(),
                        candidates.size(),
                        -1,
                        false,
                    };

                    llama_sampler_apply(sampler, &cur_p);
                    llama_token sampled_token = cur_p.data[cur_p.selected].id;

                    float conf = calculate_confidence(cur_p, params.algorithm, rng);

                    sampled_tokens[i] = sampled_token;
                    confidences.emplace_back(conf, i);
                }

                int32_t transfer_count = calculate_transfer_count(
                    step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);

                if (transfer_count > 0) {
                    if (params.alg_temp == 0.0f) {
                        std::partial_sort(confidences.begin(),
                                          confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
                                          confidences.end(),
                                          [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
                                              if (a.first != b.first) {
                                                  return a.first > b.first;
                                              }
                                              return a.second < b.second;
                                          });

                        for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
                            int32_t mask_idx = confidences[i].second;
                            int32_t pos = mask_positions[mask_idx];
                            output_tokens[pos] = sampled_tokens[mask_idx];
                        }
                    } else {
                        conf_candidates.clear();
                        for (size_t i = 0; i < confidences.size(); i++) {
                            float conf_logit = confidences[i].first / params.alg_temp;
                            conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
                        }

                        llama_token_data_array conf_array = {
                            conf_candidates.data(),
                            conf_candidates.size(),
                            -1,
                            false,
                        };

                        for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
                            llama_sampler_apply(dist_sampler, &conf_array);
                            int32_t selected_idx = conf_array.selected;
                            int32_t mask_idx = selected_idx;
                            int32_t pos = mask_positions[mask_idx];
                            output_tokens[pos] = sampled_tokens[mask_idx];

                            conf_candidates[selected_idx].p = 0.0f;
                            conf_array.selected = -1;
                        }
                    }
                }
            }

            int64_t time_end_sampling = ggml_time_us();
            total_sampling_time += time_end_sampling - time_start_sampling;
        }
    }

    int64_t time_end = ggml_time_us();
    total_time += time_end - time_start;

    LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
            total_time / 1000.0,
            total_time / 1000.0 / params.steps,
            total_sampling_time / 1000.0 / params.steps);

    llama_batch_free(batch);
    llama_sampler_free(sampler);
    llama_sampler_free(dist_sampler);

    n_generated = params.max_length;
}

static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) {
    if (!use_chat_template) {
        return prompt;
@@ -631,10 +192,10 @@ int main(int argc, char ** argv) {
    GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0));

    if (params.diffusion.eps) {
-        diff_params.schedule = TIMESTEP_BASED;
+        diff_params.schedule = DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED;
        diff_params.eps = params.diffusion.eps;
    } else if (params.diffusion.block_length) {
-        diff_params.schedule = BLOCK_BASED;
+        diff_params.schedule = DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED;
        diff_params.block_length = params.diffusion.block_length;
    }

@@ -653,8 +214,17 @@ int main(int argc, char ** argv) {
    callback_data cb_data = { &diff_params, vocab, n_input };
    diff_params.step_callback_user_data = &cb_data;

-    const char * alg_names[] = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" };
-    const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" };
+    const char * alg_names[] = {
+        "DIFFUSION_ALGORITHM_ORIGIN",
+        "DIFFUSION_ALGORITHM_ENTROPY_BASED",
+        "DIFFUSION_ALGORITHM_MARGIN_BASED",
+        "DIFFUSION_ALGORITHM_RANDOM",
+        "DIFFUSION_ALGORITHM_CONFIDENCE_BASED",
+    };
+    const char * sched_names[] = {
+        "DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED",
+        "DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED",
+    };
    const char * alg_name =
        (diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN";
    const char * sched_name =

@@ -666,11 +236,11 @@ int main(int argc, char ** argv) {
    LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name);
    LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name);
    LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature);
-    if (diff_params.schedule == TIMESTEP_BASED) {
+    if (diff_params.schedule == DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED) {
        LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps);
        LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp);
    }
-    if (diff_params.schedule == BLOCK_BASED) {
+    if (diff_params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) {
        LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length);
        LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale);
    }

examples/diffusion/diffusion.cpp (new file, 408 lines)
@@ -0,0 +1,408 @@
|
||||
#include "diffusion.h"
|
||||
|
||||
#include "log.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
static float calculate_confidence(const llama_token_data_array & cur_p,
|
||||
diffusion_algorithm algorithm,
|
||||
std::mt19937 & rng) {
|
||||
switch (algorithm) {
|
||||
case DIFFUSION_ALGORITHM_CONFIDENCE_BASED:
|
||||
return cur_p.data[cur_p.selected].p; // Selected token probability
|
||||
|
||||
case DIFFUSION_ALGORITHM_ENTROPY_BASED:
|
||||
{
|
||||
float entropy = 0.0f;
|
||||
const float epsilon = 1e-10f;
|
||||
for (size_t i = 0; i < cur_p.size; i++) {
|
||||
float prob = cur_p.data[i].p;
|
||||
entropy += prob * logf(prob + epsilon);
|
||||
}
|
||||
return -entropy; // Higher entropy = lower confidence
|
||||
}
|
||||
|
||||
case DIFFUSION_ALGORITHM_MARGIN_BASED:
|
||||
return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;
|
||||
|
||||
case DIFFUSION_ALGORITHM_RANDOM:
|
||||
{
|
||||
std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
|
||||
return uniform(rng); // Random confidence
|
||||
}
|
||||
|
||||
case DIFFUSION_ALGORITHM_ORIGIN:
|
||||
return cur_p.data[cur_p.selected].p;
|
||||
|
||||
default:
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// Unified transfer count calculation function
|
||||
static int32_t calculate_transfer_count(int32_t step,
|
||||
int32_t total_steps,
|
||||
int32_t remaining_masked,
|
||||
diffusion_transfer_schedule schedule,
|
||||
float eps,
|
||||
const std::vector<int32_t> & num_transfer_tokens = {}) {
|
||||
switch (schedule) {
|
||||
case DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED:
|
||||
{
|
||||
float t = 1.0f - (float) step / total_steps * (1.0f - eps);
|
||||
float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
|
||||
float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
|
||||
return (int32_t) (remaining_masked * p_transfer);
|
||||
}
|
||||
|
||||
case DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED:
|
||||
if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
|
||||
return num_transfer_tokens[step];
|
||||
}
|
||||
return remaining_masked / (total_steps - step); // Fallback
|
||||
|
||||
default:
|
||||
return remaining_masked / (total_steps - step);
|
||||
}
|
||||
}
|
||||
|
||||
static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) {
|
||||
if (temperature == 0.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::uniform_real_distribution<double> uniform(0.0, 1.0);
|
||||
for (int32_t i = 0; i < n_vocab; i++) {
|
||||
double noise = uniform(rng);
|
||||
// Prevent log(0)
|
||||
noise = std::max(noise, 1e-20);
|
||||
double gumbel_noise = std::pow(-std::log(noise), temperature);
|
||||
logits[i] = std::exp(logits[i]) / gumbel_noise;
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
|
||||
std::vector<int32_t> num_transfer_tokens(steps);
|
||||
|
||||
int32_t base = mask_count / steps;
|
||||
int32_t remainder = mask_count % steps;
|
||||
|
||||
for (int32_t i = 0; i < steps; i++) {
|
||||
num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
|
||||
}
|
||||
|
||||
return num_transfer_tokens;
|
||||
}
|
||||
|
||||
void diffusion_generate(llama_context * ctx,
|
||||
const llama_token * input_tokens,
|
||||
llama_token * output_tokens,
|
||||
int32_t n_input,
|
||||
const diffusion_params & params,
|
||||
int32_t & n_generated) {
|
||||
n_generated = 0;
|
||||
if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
|
||||
// Initialize with input and pad with mask tokens
|
||||
std::copy(input_tokens, input_tokens + n_input, output_tokens);
|
||||
std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
|
||||
llama_set_causal_attn(ctx, false);
|
||||
|
||||
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
||||
|
||||
std::vector<llama_token_data> candidates(n_vocab);
|
||||
std::vector<llama_token_data> conf_candidates;
|
||||
conf_candidates.reserve(params.max_length);
|
||||
std::vector<int32_t> mask_positions;
|
||||
mask_positions.reserve(params.max_length);
|
||||
|
||||
// Setup sampler chain
|
||||
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
|
||||
if (params.top_k > 0) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
|
||||
}
|
||||
if (params.top_p < 1.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
|
||||
}
|
||||
if (params.temperature > 0.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
|
||||
}
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));
|
||||
|
||||
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
|
||||
|
||||
llama_batch batch = llama_batch_init(params.max_length, 0, 1);
|
||||
batch.n_tokens = params.max_length;
|
||||
|
||||
// Pre-allocate buffers for CFG if needed
|
||||
int32_t logits_size = n_vocab * params.max_length;
|
||||
std::vector<float> cond_logits_buffer;
|
||||
std::vector<llama_token> un_x_buffer;
|
||||
if (params.cfg_scale > 0.0f) {
|
||||
cond_logits_buffer.resize(logits_size);
|
||||
un_x_buffer.resize(params.max_length);
|
||||
}
|
||||
|
||||
// For block-based processing
|
||||
std::vector<int32_t> num_transfer_tokens;
|
||||
int32_t num_blocks = 1;
|
||||
int32_t steps_per_block = params.steps;
|
||||
|
||||
if (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) {
|
||||
GGML_ASSERT(params.max_length % params.block_length == 0);
|
||||
num_blocks = params.max_length / params.block_length;
|
||||
GGML_ASSERT(params.steps % num_blocks == 0);
|
||||
steps_per_block = params.steps / num_blocks;
|
||||
}
|
||||
|
||||
std::vector<float> confidence(params.max_length);
|
||||
|
||||
int64_t total_sampling_time = 0;
|
||||
int64_t total_time = 0;
|
||||
int64_t time_start = ggml_time_us();
|
||||
|
||||
for (int block_num = 0; block_num < num_blocks; block_num++) {
|
||||
int32_t block_start = (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
|
||||
int32_t block_end = (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) ?
|
||||
std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
|
||||
params.max_length;
|
||||
|
||||
// Count masked tokens in current block for block-based processing
|
||||
if (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) {
|
||||
int32_t block_mask_count = 0;
|
||||
for (int i = block_start; i < block_end; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
block_mask_count++;
|
||||
}
|
||||
}
|
||||
num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
|
||||
}
|
||||
|
||||
for (int32_t step = 0; step < steps_per_block; step++) {
|
||||
int32_t global_step = block_num * steps_per_block + step;
|
||||
|
||||
if (params.step_callback) {
|
||||
if (!params.step_callback(
|
||||
global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Setup batch
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
batch.token[i] = output_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.logits[i] = 1;
|
||||
}
|
||||
|
||||
float * logits = nullptr;
|
||||
|
||||
if (params.cfg_scale > 0.0f) {
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("Failed to generate conditional");
|
||||
break;
|
||||
}
|
||||
float * cond_logits_ptr = llama_get_logits(ctx);
|
||||
std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));
|
||||
|
||||
// Unconditional generation (mask input)
|
||||
std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
|
||||
for (int32_t i = 0; i < n_input; i++) {
|
||||
un_x_buffer[i] = params.mask_token_id;
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
batch.token[i] = un_x_buffer[i];
|
||||
}
|
||||
ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("Failed to generate unconditional");
|
||||
break;
|
||||
}
|
||||
float * uncond_logits = llama_get_logits(ctx);
|
||||
|
||||
// Apply CFG
|
||||
for (int32_t i = 0; i < logits_size; i++) {
|
||||
cond_logits_buffer[i] =
|
||||
uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
|
||||
}
|
||||
logits = cond_logits_buffer.data();
|
||||
} else {
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
|
||||
break;
|
||||
}
|
||||
logits = llama_get_logits(ctx);
|
||||
}
|
||||
|
||||
if (!logits) {
|
||||
LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
|
||||
break;
|
||||
}
|
||||
|
||||
auto get_logits_for_pos = [&](int32_t pos) -> const float * {
|
||||
if (params.shift_logits) {
|
||||
return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
|
||||
}
|
||||
return logits + pos * n_vocab;
|
||||
};
|
||||
|
||||
int64_t time_start_sampling = ggml_time_us();
|
||||
|
||||
mask_positions.clear();
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
// For block-based, only consider current block
|
||||
if (params.schedule != DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED || (i >= block_start && i < block_end)) {
|
||||
mask_positions.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mask_positions.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (params.add_gumbel_noise && params.temperature > 0.0f) {
|
||||
add_gumbel_noise(logits, n_vocab, params.temperature, rng);
|
||||
}
|
||||
|
||||
if (params.algorithm == DIFFUSION_ALGORITHM_ORIGIN) {
|
||||
                int32_t transfer_count = calculate_transfer_count(
                    step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
                float p_transfer = (float) transfer_count / mask_positions.size();

                for (int32_t pos : mask_positions) {
                    if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
                        const float * pos_logits = get_logits_for_pos(pos);
                        for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                            candidates[token_id].id    = token_id;
                            candidates[token_id].logit = pos_logits[token_id];
                            candidates[token_id].p     = 0.0f;
                        }

                        llama_token_data_array cur_p = {
                            candidates.data(),
                            (size_t) n_vocab,
                            -1,
                            false,
                        };

                        llama_sampler_apply(sampler, &cur_p);
                        output_tokens[pos] = cur_p.data[cur_p.selected].id;
                    }
                }
            } else {
                std::vector<std::pair<float, int32_t>> confidences;
                std::vector<llama_token> sampled_tokens(mask_positions.size());

                for (size_t i = 0; i < mask_positions.size(); i++) {
                    int32_t pos = mask_positions[i];
                    const float * pos_logits = get_logits_for_pos(pos);

                    for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                        candidates[token_id].logit = pos_logits[token_id];
                        candidates[token_id].p     = 0.0f;
                        candidates[token_id].id    = token_id;
                    }

                    llama_token_data_array cur_p = {
                        candidates.data(),
                        candidates.size(),
                        -1,
                        false,
                    };

                    llama_sampler_apply(sampler, &cur_p);
                    llama_token sampled_token = cur_p.data[cur_p.selected].id;

                    float conf = calculate_confidence(cur_p, params.algorithm, rng);

                    sampled_tokens[i] = sampled_token;
                    confidences.emplace_back(conf, i);
                }

                int32_t transfer_count = calculate_transfer_count(
                    step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);

                if (transfer_count > 0) {
                    if (params.alg_temp == 0.0f) {
                        std::partial_sort(confidences.begin(),
                            confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
                            confidences.end(),
                            [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
                                if (a.first != b.first) {
                                    return a.first > b.first;
                                }
                                return a.second < b.second;
                            });

                        for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
                            int32_t mask_idx = confidences[i].second;
                            int32_t pos = mask_positions[mask_idx];
                            output_tokens[pos] = sampled_tokens[mask_idx];
                        }
                    } else {
                        conf_candidates.clear();
                        for (size_t i = 0; i < confidences.size(); i++) {
                            float conf_logit = confidences[i].first / params.alg_temp;
                            conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
                        }

                        llama_token_data_array conf_array = {
                            conf_candidates.data(),
                            conf_candidates.size(),
                            -1,
                            false,
                        };

                        for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
                            llama_sampler_apply(dist_sampler, &conf_array);
                            int32_t selected_idx = conf_array.selected;
                            int32_t mask_idx = selected_idx;
                            int32_t pos = mask_positions[mask_idx];
                            output_tokens[pos] = sampled_tokens[mask_idx];

                            conf_candidates[selected_idx].p = 0.0f;
                            conf_array.selected = -1;
                        }
                    }
                }
            }

            int64_t time_end_sampling = ggml_time_us();
            total_sampling_time += time_end_sampling - time_start_sampling;
        }
    }

    int64_t time_end = ggml_time_us();
    total_time += time_end - time_start;

    LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
            total_time / 1000.0,
            total_time / 1000.0 / params.steps,
            total_sampling_time / 1000.0 / params.steps);

    llama_batch_free(batch);
    llama_sampler_free(sampler);
    llama_sampler_free(dist_sampler);

    n_generated = params.max_length;
}
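Aside (illustrative, not part of the patch): the deterministic alg_temp == 0 branch above reduces to a top-k selection by confidence, committing only the highest-confidence sampled tokens. A self-contained sketch of just that step, with made-up values:

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    // (confidence, index-into-mask_positions) pairs, as built in the loop above
    std::vector<std::pair<float, int>> confidences = {
        { 0.91f, 0 }, { 0.13f, 1 }, { 0.77f, 2 }, { 0.42f, 3 },
    };
    const int transfer_count = 2;

    const int k = std::min(transfer_count, (int) confidences.size());
    std::partial_sort(confidences.begin(), confidences.begin() + k, confidences.end(),
        [](const std::pair<float, int> & a, const std::pair<float, int> & b) {
            if (a.first != b.first) {
                return a.first > b.first; // higher confidence first
            }
            return a.second < b.second;   // deterministic tie-break on position
        });

    for (int i = 0; i < k; i++) {
        printf("commit mask index %d (conf %.2f)\n", confidences[i].second, confidences[i].first);
    }
    return 0; // prints indices 0 and 2
}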
examples/diffusion/diffusion.h (new file, 57 lines)
@@ -0,0 +1,57 @@
#pragma once

#include "llama.h"

#include <cstdint>

enum diffusion_algorithm {
    DIFFUSION_ALGORITHM_ORIGIN           = 0,
    DIFFUSION_ALGORITHM_ENTROPY_BASED    = 1,
    DIFFUSION_ALGORITHM_MARGIN_BASED     = 2,
    DIFFUSION_ALGORITHM_RANDOM           = 3,
    DIFFUSION_ALGORITHM_CONFIDENCE_BASED = 4,
};

// Unified transfer scheduling methods
enum diffusion_transfer_schedule {
    DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining
    DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED    = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens
};

typedef bool (*diffusion_step_callback_t)(int32_t step,
                                          int32_t total_steps,
                                          const llama_token * tokens,
                                          int32_t n_tokens,
                                          void * user_data);

struct diffusion_params {
    int32_t steps = 0;
    float temperature = 0;
    llama_token mask_token_id = LLAMA_TOKEN_NULL;
    diffusion_step_callback_t step_callback = nullptr;
    void * step_callback_user_data = nullptr;
    int32_t seed = 0;
    bool visual_mode = false;
    bool shift_logits = false; // Shift logits by -1 after decode

    float top_p = 0.;
    int32_t top_k = 0;

    diffusion_algorithm algorithm = DIFFUSION_ALGORITHM_CONFIDENCE_BASED;
    diffusion_transfer_schedule schedule = DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED;

    float cfg_scale = 0.;          // Config scale for classifier-free guidance
    float eps = 0.;                // Timestep scheduling
    int32_t block_length = 0;      // Block size (for block scheduling)
    float alg_temp = 0;            // algorithm temperature (0.0 = deterministic)
    bool add_gumbel_noise = false; // Add gumbel noise to the logits if temp > 0.0

    int32_t max_length = 0;        // Maximum sequence length
};

void diffusion_generate(llama_context * ctx,
                        const llama_token * input_tokens,
                        llama_token * output_tokens,
                        int32_t n_input,
                        const diffusion_params & params,
                        int32_t & n_generated);
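Aside (illustrative, not part of the patch): a minimal sketch of how a caller might drive this new API. The context, prompt tokens and mask token id are assumed to be supplied by the surrounding example code, and every parameter value below is a placeholder, not a recommended default:

#include "diffusion.h"

#include <vector>

static void run_diffusion(llama_context * ctx, const llama_token * prompt_tokens,
                          int32_t n_prompt, llama_token mask_token_id) {
    diffusion_params params;
    params.steps         = 64;
    params.max_length    = 256;
    params.mask_token_id = mask_token_id; // model-specific mask token
    params.algorithm     = DIFFUSION_ALGORITHM_CONFIDENCE_BASED;
    params.schedule      = DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED;
    params.eps           = 1e-3f;

    std::vector<llama_token> output(params.max_length);
    int32_t n_generated = 0;

    diffusion_generate(ctx, prompt_tokens, output.data(), n_prompt, params, n_generated);
    // on success, output[0 .. n_generated) holds the denoised sequence
}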
@@ -38,8 +38,12 @@ int main(int argc, char ** argv) {
    std::string result0;
    std::string result1;
    std::string result2;
    std::string result3;

    // init

    ggml_backend_load_all();

    auto llama_init = common_init_from_params(params);

    auto * model = llama_init->model();
@@ -213,11 +217,83 @@ int main(int argc, char ** argv) {
        n_past += 1;
    }

    // test on-device state save/load
    auto params_ctx4 = common_context_params_to_llama(params);
    params_ctx4.n_seq_max = 2;
    llama_context * ctx4 = llama_init_from_model(model, params_ctx4);

    llama_sampler * smpl4 = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(smpl4, llama_sampler_init_dist(params.sampling.seed));

    printf("\nsingle seq run: %s", params.prompt.c_str());

    // load state (rng, logits, embedding and kv_cache) from file
    n_token_count_out = 0;

    if (!llama_state_load_file(ctx4, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
        fprintf(stderr, "\n%s : failed to load state\n", __func__);
        return 1;
    }

    fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);

    // restore state (last tokens)
    n_past = n_token_count_out;
    if (!common_replay_last_token(ctx4, tokens.back(), n_past)) {
        return 1;
    }
    ++n_past;

    // save seq 0 and load into seq 1
    {
        // save kv of seq 0
        std::vector<uint8_t> seq_store(llama_state_seq_get_size_ext(ctx4, 0, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE));
        const size_t ncopy = llama_state_seq_get_data_ext(ctx4, seq_store.data(), seq_store.size(), 0, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
        if (ncopy != seq_store.size()) {
            fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
            return 1;
        }
        fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);

        // erase whole kv
        llama_memory_clear(llama_get_memory(ctx4), true);
        fprintf(stderr, "%s : kv cache cleared\n", __func__);

        // restore kv into seq 1
        const size_t nset = llama_state_seq_set_data_ext(ctx4, seq_store.data(), seq_store.size(), 1, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
        if (nset != seq_store.size()) {
            fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
            return 1;
        }
        fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
    }

    // fourth run
    for (auto i = 0; i < params.n_predict; i++) {
        auto next_token = llama_sampler_sample(smpl4, ctx4, -1);
        auto next_token_str = common_token_to_piece(ctx4, next_token);

        printf("%s", next_token_str.c_str());
        result3 += next_token_str;

        common_batch_clear(batch);
        common_batch_add(batch, next_token, n_past, {1}, true);

        if (llama_decode(ctx4, batch)) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_batch_free(batch);
            return 1;
        }
        n_past += 1;
    }

    printf("\n");

    llama_sampler_free(smpl);
    llama_sampler_free(smpl2);
    llama_sampler_free(smpl3);
    llama_sampler_free(smpl4);

    llama_batch_free(batch);

@@ -226,12 +302,18 @@ int main(int argc, char ** argv) {

    llama_free(ctx2);
    llama_free(ctx3);
    llama_free(ctx4);

    if (result0 != result2) {
        fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
        return 1;
    }

    if (result0 != result3) {
        fprintf(stderr, "\n%s : error : the on-device seq restore generation is different\n", __func__);
        return 1;
    }

    fprintf(stderr, "\n%s : success\n", __func__);

    return 0;

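Aside (illustrative, not part of the patch): the on-device copy exercised above boils down to serializing one sequence with the ON_DEVICE flag and deserializing it into another. A trimmed sketch using only the functions the test calls, with error handling reduced to a boolean:

#include "llama.h"

#include <cstdint>
#include <vector>

static bool copy_seq_state(llama_context * ctx, llama_seq_id src, llama_seq_id dst) {
    // size, then serialize seq `src` (KV kept on device per the new flag)
    std::vector<uint8_t> buf(llama_state_seq_get_size_ext(ctx, src, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE));
    const size_t n_copied = llama_state_seq_get_data_ext(ctx, buf.data(), buf.size(), src, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
    if (n_copied != buf.size()) {
        return false;
    }

    // deserialize the same bytes into seq `dst`
    const size_t n_set = llama_state_seq_set_data_ext(ctx, buf.data(), buf.size(), dst, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
    return n_set == buf.size();
}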
@@ -110,13 +110,21 @@ int main(int argc, char ** argv) {
        return 1;
    }

    if (
        llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
        llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
        llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
        llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
    ) {
        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
    if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
        (llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) {
        LOG_ERR("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
                __func__,
                llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft),
                llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft));
        return 1;
    }

    if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
        (llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) {
        LOG_ERR("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
                __func__,
                llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft),
                llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft));
        return 1;
    }

@@ -137,11 +145,12 @@ int main(int argc, char ** argv) {
    for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
        const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
        const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);

        if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
            LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
            LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
                    common_token_to_piece(ctx_tgt, i).c_str(),
                    common_token_to_piece(ctx_dft, i).c_str());
                    common_token_to_piece(vocab_tgt, i).c_str(),
                    common_token_to_piece(vocab_dft, i).c_str());
            return 1;
        }
    }

@@ -111,14 +111,14 @@ if [ $GGML_SYCL_DEVICE -ne -1 ]; then
    echo "Use $GGML_SYCL_DEVICE as main GPU"
    # use single GPU only
    GPUS_SETTING="-mg $GGML_SYCL_DEVICE -sm ${SPLIT_MODE}"
    export ONEAPI_DEVICE_SELECTOR="level_zero:${$GGML_SYCL_DEVICE}"
    export ONEAPI_DEVICE_SELECTOR="level_zero:${GGML_SYCL_DEVICE}"
    echo "ONEAPI_DEVICE_SELECTOR=${ONEAPI_DEVICE_SELECTOR}"
else
    echo "Use all Intel GPUs, including iGPU & dGPU"
    GPUS_SETTING="-sm ${SPLIT_MODE}"
fi

echo "run cmd: ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 200 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap "
echo "run cmd: ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap --host 0.0.0.0 --port 8000"
ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap --host 0.0.0.0 --port 8000

@@ -119,7 +119,7 @@ if [ $GGML_SYCL_DEVICE -ne -1 ]; then
    echo "Use $GGML_SYCL_DEVICE as main GPU"
    # use single GPU only
    GPUS_SETTING="-mg $GGML_SYCL_DEVICE -sm ${SPLIT_MODE}"
    export ONEAPI_DEVICE_SELECTOR="level_zero:${$GGML_SYCL_DEVICE}"
    export ONEAPI_DEVICE_SELECTOR="level_zero:${GGML_SYCL_DEVICE}"
    echo "ONEAPI_DEVICE_SELECTOR=${ONEAPI_DEVICE_SELECTOR}"
else
    echo "Use all Intel GPUs, including iGPU & dGPU"

@@ -4,8 +4,8 @@ project("ggml" C CXX ASM)

### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 10)
set(GGML_VERSION_PATCH 1)
set(GGML_VERSION_MINOR 11)
set(GGML_VERSION_PATCH 0)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")

list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

@@ -169,7 +169,7 @@ extern "C" {
        // device type
        enum ggml_backend_dev_type type;
        // device id
        // for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0")
        // for PCI devices, this should be the lower-case PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:c1:00.0")
        // if the id is unknown, this should be NULL
        const char * device_id;
        // device capabilities

@@ -438,6 +438,12 @@ extern "C" {
        GGML_PREC_F32 = 10,
    };

    // op hint
    enum ggml_op_hint {
        GGML_HINT_NONE             = 0,
        GGML_HINT_SRC0_IS_HADAMARD = 1,
    };

    // model file types
    enum ggml_ftype {
        GGML_FTYPE_UNKNOWN = -1,
@@ -1419,6 +1425,11 @@ extern "C" {
            struct ggml_tensor * a,
            enum ggml_prec prec);

    // change the hint of a matrix multiplication
    GGML_API void ggml_mul_mat_set_hint(
            struct ggml_tensor * a,
            enum ggml_op_hint hint);

    // indirect matrix multiplication
    GGML_API struct ggml_tensor * ggml_mul_mat_id(
            struct ggml_context * ctx,

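Aside (illustrative, not part of the patch): the intended use of the new hint API, sketched under the assumption that the caller knows src0 (`had` below) holds a Hadamard matrix. With the hint set, the CPU backend reroutes the mul_mat through the FWHT path shown later in this diff:

#include "ggml.h"

// `had` is assumed to be a (power-of-two sized) Hadamard matrix.
static struct ggml_tensor * hadamard_transform(struct ggml_context * ctx,
                                               struct ggml_tensor * had,
                                               struct ggml_tensor * x) {
    struct ggml_tensor * y = ggml_mul_mat(ctx, had, x);
    ggml_mul_mat_set_hint(y, GGML_HINT_SRC0_IS_HADAMARD); // enables the O(n log n) FWHT path
    return y;
}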
@@ -2100,8 +2100,8 @@ static const ggml_backend_i ggml_backend_meta_i = {
    /* .free                  = */ ggml_backend_meta_free,
    /* .set_tensor_async      = */ ggml_backend_meta_set_tensor_async,
    /* .get_tensor_async      = */ ggml_backend_meta_get_tensor_async,
    /* .get_tensor_2d_async   = */ nullptr,
    /* .set_tensor_2d_async   = */ nullptr,
    /* .get_tensor_2d_async   = */ nullptr,
    /* .cpy_tensor_async      = */ nullptr,
    /* .synchronize           = */ ggml_backend_meta_synchronize,
    /* .graph_plan_create     = */ nullptr,

@@ -965,7 +965,7 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
        }
        if (sched->debug > 1) {
            ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name,
            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_desc(node), node->name,
                fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
                graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0);
            for (int j = 0; j < GGML_MAX_SRC; j++) {

@@ -262,9 +262,9 @@ static struct ggml_backend_i blas_backend_i = {
    /* .get_name              = */ ggml_backend_blas_get_name,
    /* .free                  = */ ggml_backend_blas_free,
    /* .set_tensor_async      = */ NULL,
    /* .get_tensor_2d_async   = */ NULL,
    /* .set_tensor_2d_async   = */ NULL,
    /* .get_tensor_async      = */ NULL,
    /* .set_tensor_2d_async   = */ NULL,
    /* .get_tensor_2d_async   = */ NULL,
    /* .cpy_tensor_async      = */ NULL,
    /* .synchronize           = */ NULL,
    /* .graph_plan_create     = */ NULL,

@@ -2746,8 +2746,8 @@ static const ggml_backend_i ggml_backend_cann_interface = {
    /* .free                  = */ ggml_backend_cann_free,
    /* .set_tensor_async      = */ ggml_backend_cann_set_tensor_async,
    /* .get_tensor_async      = */ ggml_backend_cann_get_tensor_async,
    /* .get_tensor_2d_async   = */ NULL,
    /* .set_tensor_2d_async   = */ NULL,
    /* .get_tensor_2d_async   = */ NULL,
    /* .cpy_tensor_async      = */ ggml_backend_cann_cpy_tensor_async,
    /* .synchronize           = */ ggml_backend_cann_synchronize,
    /* .graph_plan_create     = */ NULL,

@@ -578,13 +578,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)

        # Fetch KleidiAI sources:
        include(FetchContent)
        set(KLEIDIAI_COMMIT_TAG "v1.22.0")
        set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
        set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1")
        set(KLEIDIAI_COMMIT_TAG "v1.24.0")
        set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/releases/download/${KLEIDIAI_COMMIT_TAG}/kleidiai-${KLEIDIAI_COMMIT_TAG}-src.tar.gz")
        set(KLEIDIAI_RELEASE_ARCHIVE_MD5 "2f02ebe29573d45813e671eb304f2a00")

        set(KLEIDIAI_FETCH_ARGS
            URL ${KLEIDIAI_DOWNLOAD_URL}
            URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}
            URL_HASH MD5=${KLEIDIAI_RELEASE_ARCHIVE_MD5}
        )
        if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
            list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW)

@@ -203,7 +203,6 @@
#elif defined(__riscv)
// quants.c
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4

@@ -480,6 +480,104 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
#endif
}

#if defined(__riscv_v)
static NOINLINE void ggml_vec_dot_q1_0_q8_0_vl256(const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy) {
    const int qk = QK1_0;
    const int nb = n / qk;
    assert(n % qk == 0);

    const block_q1_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    // LMUL = 1, VLMAX = 32
    const size_t vl32 = __riscv_vsetvl_e8m1(32);
    assert(vl32 == 32);

    const vint16m1_t zero = __riscv_vmv_v_x_i16m1(0, 1);

    float sumf = 0;

    for (int ib = 0; ib < nb; ++ib) {
        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);

        float acc = 0;

        for (int k = 0; k < 4; ++k) {
            const block_q8_0 * GGML_RESTRICT yb = &y[ib * 4 + k];
            const vbool8_t is_not_zero = __riscv_vlm_v_b8(x[ib].qs + 4 * k, vl32);

            const vint8m1_t qy     = __riscv_vle8_v_i8m1(yb->qs, vl32);
            const vint8m1_t neg_qy = __riscv_vneg_v_i8m1(qy, vl32);
            const vint8m1_t sy     = __riscv_vmerge_vvm_i8m1(neg_qy, qy, is_not_zero, vl32);

            const vint16m1_t red = __riscv_vwredsum_vs_i8m1_i16m1(sy, zero, vl32);
            acc += GGML_CPU_FP16_TO_FP32(yb->d) * (float) __riscv_vmv_x_s_i16m1_i16(red);
        }

        sumf += d0 * acc;
    }

    *s = sumf;
}

static NOINLINE void ggml_vec_dot_q1_0_q8_0_vl128(const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy) {
    const int qk = QK1_0;
    const int nb = n / qk;
    assert(n % qk == 0);

    const block_q1_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    // LMUL = 2, VLMAX = 32
    const size_t vl32 = __riscv_vsetvl_e8m2(32);
    assert(vl32 == 32);

    const vint16m1_t zero = __riscv_vmv_v_x_i16m1(0, 1);

    float sumf = 0;

    for (int ib = 0; ib < nb; ++ib) {
        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);

        float acc = 0;

        for (int k = 0; k < 4; ++k) {
            const block_q8_0 * GGML_RESTRICT yb = &y[ib * 4 + k];
            const vbool4_t is_not_zero = __riscv_vlm_v_b4(x[ib].qs + 4 * k, vl32);

            const vint8m2_t qy     = __riscv_vle8_v_i8m2(yb->qs, vl32);
            const vint8m2_t neg_qy = __riscv_vneg_v_i8m2(qy, vl32);
            const vint8m2_t sy     = __riscv_vmerge_vvm_i8m2(neg_qy, qy, is_not_zero, vl32);

            const vint16m1_t red = __riscv_vwredsum_vs_i8m2_i16m1(sy, zero, vl32);
            acc += GGML_CPU_FP16_TO_FP32(yb->d) * (float) __riscv_vmv_x_s_i16m1_i16(red);
        }

        sumf += d0 * acc;
    }

    *s = sumf;
}
#endif

void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined(__riscv_v)
    assert(nrc == 1);

    const size_t vlen_bits = __riscv_vlenb() * 8;

    if (vlen_bits >= 256) {
        ggml_vec_dot_q1_0_q8_0_vl256(n, s, vx, vy);
    } else if (vlen_bits >= 128) {
        ggml_vec_dot_q1_0_q8_0_vl128(n, s, vx, vy);
    } else {
        ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
    }
#else
    ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);

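Aside (illustrative, not part of the patch): a scalar sketch of the sign-select trick both RVV kernels use. Each q1_0 mask bit chooses between +qy and -qy, which is what the vmerge of qy and its negation computes in vector form; the per-block scales d0 and d are applied after the integer reduction. The bit layout is simplified here and the helper is hypothetical, not the generic kernel:

#include <cstdint>

// One 32-element sub-block: `bits` holds 32 sign bits, `qy` holds 32 int8 activations.
static float dot_q1_subblock(const uint8_t * bits, const int8_t * qy, float d0, float d) {
    int32_t sum = 0;
    for (int i = 0; i < 32; i++) {
        const bool bit = (bits[i / 8] >> (i % 8)) & 1;
        sum += bit ? qy[i] : -qy[i]; // vmerge(neg_qy, qy, mask) in the kernels above
    }
    return d0 * d * (float) sum;     // scales applied once per reduction
}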
@@ -1245,6 +1245,12 @@ void ggml_compute_forward_mul_mat(
    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    const int32_t hint = ggml_get_op_params_i32(dst, 1);
    if (hint == GGML_HINT_SRC0_IS_HADAMARD && !params->use_ref) {
        ggml_compute_forward_fwht(params, dst);
        return;
    }

    GGML_TENSOR_BINARY_OP_LOCALS

    const int ith = params->ith;
@@ -2959,6 +2965,45 @@ struct ggml_cplan ggml_graph_plan(
    return cplan;
}

static bool ggml_cpu_disable_fusion = false; // initialized once in ggml_cpu_init(), read-only afterwards

// Try to fuse the current node with subsequent nodes for better performance.
// Returns the number of nodes skipped by fusion (>=1), or 0 if no fusion was applied.
static int ggml_cpu_try_fuse_ops(
        const struct ggml_cgraph * cgraph,
        const int node_n,
        const struct ggml_compute_params * params,
        const struct ggml_cplan * cplan) {

    if (ggml_cpu_disable_fusion || cplan->use_ref) {
        return 0;
    }

    struct ggml_tensor * node = cgraph->nodes[node_n];

    if (node->op == GGML_OP_RMS_NORM) {
        // RMS_NORM + MUL fusion
        const enum ggml_op fuse_ops[] = { GGML_OP_RMS_NORM, GGML_OP_MUL };
        if (ggml_can_fuse(cgraph, node_n, fuse_ops, 2)) {
            struct ggml_tensor * mul_node = cgraph->nodes[node_n + 1];
            const struct ggml_tensor * mul_w = (mul_node->src[0] == node)
                ? mul_node->src[1] : mul_node->src[0];
            if (node->src[0]->type == GGML_TYPE_F32 &&
                mul_node->type    == GGML_TYPE_F32 &&
                mul_w->type       == GGML_TYPE_F32 &&
                mul_w->ne[0] == node->ne[0] &&
                mul_w->nb[0] == sizeof(float)) {

                ggml_compute_forward_rms_norm_mul_fused(params, node, mul_node);
                return 1;
            }
        }
    }

    return 0;
}

static thread_ret_t ggml_graph_compute_thread(void * data) {
    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
    struct ggml_threadpool * tp = state->threadpool;
@@ -2995,7 +3040,14 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
            continue;
        }

        ggml_compute_forward(&params, node);
        // TODO: move fused-op detection into ggml_graph_plan so fusion decisions are made once at planning time
        // Try fused ops, fall back to normal compute
        const int n_fused = ggml_cpu_try_fuse_ops(cgraph, node_n, &params, cplan);
        if (n_fused > 0) {
            node_n += n_fused;
        } else {
            ggml_compute_forward(&params, node);
        }

        if (state->ith == 0 && cplan->abort_callback &&
            cplan->abort_callback(cplan->abort_callback_data)) {
@@ -3757,6 +3809,11 @@ void ggml_cpu_init(void) {
        ggml_init_riscv_arch_features();
#endif

        {
            const char * env = getenv("GGML_CPU_DISABLE_FUSION");
            ggml_cpu_disable_fusion = (env != NULL && atoi(env) == 1);
        }

        is_first_call = false;
    }

@@ -195,8 +195,8 @@ static const struct ggml_backend_i ggml_backend_cpu_i = {
    /* .free                  = */ ggml_backend_cpu_free,
    /* .set_tensor_async      = */ NULL,
    /* .get_tensor_async      = */ NULL,
    /* .get_tensor_2d_async   = */ NULL,
    /* .set_tensor_2d_async   = */ NULL,
    /* .get_tensor_2d_async   = */ NULL,
    /* .cpy_tensor_async      = */ NULL,
    /* .synchronize           = */ NULL,
    /* .graph_plan_create     = */ ggml_backend_cpu_graph_plan_create,

@@ -3713,11 +3713,27 @@ void ggml_compute_forward_norm(

// ggml_compute_forward_group_rms_norm

// fusion kinds that can be combined with the rms_norm computation in a single pass.
// extend this enum when adding new fused variants (e.g. FUSE_ADD, FUSE_MUL_ADD, ...).
enum ggml_rms_norm_fuse_op {
    GGML_RMS_NORM_FUSE_OP_NONE,
    GGML_RMS_NORM_FUSE_OP_MUL,
};

template <ggml_rms_norm_fuse_op FUSE_OP>
static void ggml_compute_forward_rms_norm_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
        ggml_tensor * dst_rms_norm,
        ggml_tensor * dst_fused = nullptr) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src0 = dst_rms_norm->src[0];
    const ggml_tensor * src1 = nullptr;
    ggml_tensor * dst = dst_rms_norm;

    if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
        src1 = (dst_fused->src[0] == dst_rms_norm) ? dst_fused->src[1] : dst_fused->src[0];
        dst = dst_fused;
    }

    GGML_ASSERT(ggml_are_same_shape(src0, dst));

@@ -3726,11 +3742,10 @@ static void ggml_compute_forward_rms_norm_f32(
    const int ith = params->ith;
    const int nth = params->nth;

    GGML_TENSOR_UNARY_OP_LOCALS
    GGML_TENSOR_BINARY_OP_LOCALS

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    memcpy(&eps, dst_rms_norm->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    // TODO: optimize
@@ -3740,25 +3755,32 @@ static void ggml_compute_forward_rms_norm_f32(
                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

                ggml_float sum = 0.0;
                // worth switching to explicit SIMD?
                for (int64_t i00 = 0; i00 < ne00; i00++) {
                    sum += (ggml_float)(x[i00] * x[i00]);
                }

                const float mean = sum/ne00;

                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);

                memcpy(y, x, ne00 * sizeof(float));
                // for (int i00 = 0; i00 < ne00; i00++) {
                //     y[i00] = x[i00];
                // }

                const float mean = sum/ne00;
                const float scale = 1.0f/sqrtf(mean + eps);

                // if you hit this, likely you got an inf somewhere earlier
                assert(scale > 0.0f);

                ggml_vec_scale_f32(ne00, y, scale);
                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);

                if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
                    const int64_t i11 = i01 % ne11;
                    const int64_t i12 = i02 % ne12;
                    const int64_t i13 = i03 % ne13;
                    const float * w = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);

                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        y[i00] = x[i00] * scale * w[i00];
                    }
                } else {
                    memcpy(y, x, ne00 * sizeof(float));
                    ggml_vec_scale_f32(ne00, y, scale);
                }
            }
        }
    }
@@ -3773,7 +3795,31 @@ void ggml_compute_forward_rms_norm(
    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_rms_norm_f32(params, dst);
                ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_NONE>(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

// Fused RMS_NORM + MUL: computes dst = rms_norm(src0) * src1 in a single pass.
// This avoids materializing the intermediate rms_norm result in memory.
void ggml_compute_forward_rms_norm_mul_fused(
        const ggml_compute_params * params,
        ggml_tensor * dst_rms_norm,
        ggml_tensor * dst_mul) {

    GGML_ASSERT(dst_mul != nullptr);
    GGML_ASSERT(dst_mul->src[0] == dst_rms_norm || dst_mul->src[1] == dst_rms_norm);

    const ggml_tensor * src0 = dst_rms_norm->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_MUL>(params, dst_rms_norm, dst_mul);
            } break;
        default:
            {
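Aside (illustrative, not part of the patch): a standalone reference for what the FUSE_OP_MUL path computes per row, namely y[i] = x[i] * scale * w[i] with scale = 1/sqrt(mean(x^2) + eps), in one pass instead of rms_norm followed by a separate mul:

#include <cmath>
#include <cstdio>

static void rms_norm_mul(const float * x, const float * w, float * y, int n, float eps) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += (double) x[i] * x[i];
    }
    const float scale = 1.0f / sqrtf((float) (sum / n) + eps);
    for (int i = 0; i < n; i++) {
        y[i] = x[i] * scale * w[i]; // normalize and apply the weight in the same loop
    }
}

int main() {
    const float x[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
    const float w[4] = { 1.0f, 1.0f, 1.0f, 1.0f };
    float y[4];
    rms_norm_mul(x, w, y, 4, 1e-6f);
    printf("%f %f %f %f\n", y[0], y[1], y[2], y[3]);
    return 0;
}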
@@ -11212,3 +11258,91 @@ void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_
        }
    }
}

static void ggml_compute_forward_fwht_f32(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);

    GGML_TENSOR_BINARY_OP_LOCALS

    const int ith = params->ith;
    const int nth = params->nth;

    const int64_t n = ne10;
    GGML_ASSERT((n & (n - 1)) == 0); // must be power of 2

    const int64_t nr = ne11 * ne12 * ne13;
    const int64_t rows_per_thread = (nr + nth - 1) / nth;
    const int64_t start_row = ith * rows_per_thread;
    const int64_t end_row   = MIN(start_row + rows_per_thread, nr);

    const float scale = 1.0f / sqrtf((float) n);

#if defined(GGML_SIMD)
    const GGML_F32_VEC v_minus_one = GGML_F32_VEC_SET1(-1.0f);
#endif

    for (int64_t r = start_row; r < end_row; r++) {
        const int64_t i13 = r / (ne11 * ne12);
        const int64_t i12 = (r - i13 * ne11 * ne12) / ne11;
        const int64_t i11 = r - i13 * ne11 * ne12 - i12 * ne11;

        const float * src_row = (const float *) ((const char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13);
        float * dst_row = (float *) ((char *) dst->data + i11 * nb1 + i12 * nb2 + i13 * nb3);

        for (int64_t j = 0; j < n; j++) {
            dst_row[j] = src_row[j] * scale;
        }

        // Scalar passes
#if defined(GGML_SIMD)
        const int step = GGML_F32_EPR;
#else
        const int step = n;
#endif
        for (int64_t len = 1; len < step && len < n; len <<= 1) {
            for (int64_t i = 0; i < n; i += 2 * len) {
                for (int64_t j = 0; j < len; j++) {
                    float u = dst_row[i + j];
                    float v = dst_row[i + len + j];
                    dst_row[i + j]       = u + v;
                    dst_row[i + len + j] = u - v;
                }
            }
        }

        // SIMD passes using GGML_F32_VEC_* macros for multi-architecture support
#if defined(GGML_SIMD)
        for (int64_t len = step; len < n; len <<= 1) {
            for (int64_t i = 0; i < n; i += 2 * len) {
                for (int64_t j = 0; j < len; j += step) {
                    GGML_F32_VEC u = GGML_F32_VEC_LOAD(dst_row + i + j);
                    GGML_F32_VEC v = GGML_F32_VEC_LOAD(dst_row + i + len + j);

                    GGML_F32_VEC_STORE(dst_row + i + j, GGML_F32_VEC_ADD(u, v));
                    GGML_F32_VEC_STORE(dst_row + i + len + j, GGML_F32_VEC_FMA(u, v, v_minus_one));
                }
            }
        }
#endif
    }
}

void ggml_compute_forward_fwht(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src1 = dst->src[1];

    switch (src1->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_fwht_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error - fwht is F32 only");
            }
    }
}

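Aside (illustrative, not part of the patch): the same normalized Walsh-Hadamard butterflies in plain scalar form, as a worked example; each pass replaces a pair (u, v) with (u + v, u - v), and the 1/sqrt(n) prescale makes the transform orthonormal:

#include <cmath>
#include <cstdio>

static void fwht(float * v, int n) { // n must be a power of two
    const float scale = 1.0f / sqrtf((float) n);
    for (int i = 0; i < n; i++) {
        v[i] *= scale;
    }
    for (int len = 1; len < n; len <<= 1) {
        for (int i = 0; i < n; i += 2 * len) {
            for (int j = 0; j < len; j++) {
                const float u = v[i + j];
                const float w = v[i + len + j];
                v[i + j]       = u + w; // butterfly: sum
                v[i + len + j] = u - w; // butterfly: difference
            }
        }
    }
}

int main() {
    float v[4] = { 1.0f, 0.0f, 0.0f, 0.0f };
    fwht(v, 4); // an impulse spreads to a constant 0.5 across all outputs
    printf("%f %f %f %f\n", v[0], v[1], v[2], v[3]);
    return 0;
}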
@@ -44,6 +44,7 @@ void ggml_compute_forward_concat(const struct ggml_compute_params * params, stru
void ggml_compute_forward_silu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rms_norm_mul_fused(const struct ggml_compute_params * params, struct ggml_tensor * dst_rms_norm, struct ggml_tensor * dst_mul);
void ggml_compute_forward_rms_norm_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_group_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_l2_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -111,6 +112,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params *
void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_fwht(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
#ifdef __cplusplus
}

@@ -61,6 +61,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true);

    GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128,  8,  64, 4, 64, 96, 64, 64, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16,  64, 4, 32, 96, 64, 64, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 32, 96, 64, 64, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 32, 96, 64, 64, 2, true);

    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256,  8,  64, 4, 64, 128, 128, 128, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16,  64, 4, 32, 128, 128, 128, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
@@ -1561,6 +1566,10 @@ static __global__ void flash_attn_ext_f16(
        NO_DEVICE_CODE;
        return;
    }
    if (DKQ == 192 && ncols2 != 8 && ncols2 != 16) {
        NO_DEVICE_CODE;
        return;
    }
#ifdef VOLTA_MMA_AVAILABLE
    if (ncols1*ncols2 < 32) {
        NO_DEVICE_CODE;

@@ -34,6 +34,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
        } break;
        case 192: {
            GGML_ASSERT(V->ne[0] == 128);
            ggml_cuda_flash_attn_ext_tile_case<192, 128>(ctx, dst);
        } break;
        case 256: {
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);

@@ -62,13 +62,19 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128,  2,  64, 2, 64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128,  4, 128, 2, 64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128,  8, 256, 2, 64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 64, 64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 2, 64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 2, 64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2, 64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 64, 64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512,  4, 128, 2, 64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512,  8, 256, 2, 64, 64)
@@ -124,13 +130,19 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128,  2, 128, 3, 64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128,  4, 128, 3, 32, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128,  8, 256, 2, 32, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2, 128, 3, 64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 3, 32, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2, 32, 256)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 32, 64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512,  4, 128, 2, 32, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512,  8, 256, 2, 32, 64)
@@ -193,6 +205,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128,  2, 256, 2, 128, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128,  4, 256, 2,  64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128,  8, 256, 2,  64, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2,  32, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2,  32, 64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2, 256, 2, 128, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 256, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64, 128)
@@ -264,6 +282,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3,  64, 64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128,  2,  64, 8, 32, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128,  4, 128, 6, 32, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128,  8, 128, 6, 32, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 5, 32, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 3, 64, 64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 8, 32, 64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 6, 32, 256)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 128, 6, 32, 256)
@@ -1124,7 +1148,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
    constexpr size_t nbytes_shared = 0;

#ifdef GGML_USE_HIP
    if constexpr (DV <= 128) {
    if constexpr (DKQ <= 128) {
        if (Q->ne[1] > 32/ncols2) {
            constexpr int cols_per_block = 64;
            const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
@@ -1138,7 +1162,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
#endif // GGML_USE_HIP

#ifndef GGML_USE_HIP
    if constexpr (DV <= 256)
    if constexpr (DKQ <= 256)
#endif // GGML_USE_HIP
    {
        if (Q->ne[1] > 16/ncols2) {
@@ -1220,11 +1244,22 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
    const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
    const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;

    if constexpr (DKQ == 320) { // Mistral Small 4
    if constexpr (DKQ == 320) {
        // This branch is only used for Mistral Small 4 which has a GQA ratio of 32.
        // On AMD, simply use that GQA ratio with 32 columns / block since we always have enough SRAM.
        // On NVIDIA however, the tile kernel is only used for GPUs that can't use the mma kernel (Pascal and older).
        // Therefore, use a GQA ratio of 16 with 16 columns / block to stay below 48 kiB of SRAM / block.
#ifdef GGML_USE_HIP
        if (use_gqa_opt && gqa_ratio % 32 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 32, use_logit_softcap>(ctx, dst);
            return;
        }
#else
        if (use_gqa_opt && gqa_ratio % 16 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
            return;
        }
#endif // GGML_USE_HIP
        GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32");
    }

@@ -1239,7 +1274,20 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
        }
    }

    if constexpr (DKQ <= 512 && DKQ != 320) {
    if constexpr (DKQ == 192) {
        // MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn)
        if (use_gqa_opt && gqa_ratio % 16 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
            return;
        }
        if (use_gqa_opt && gqa_ratio % 8 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
            return;
        }
        GGML_ABORT("flash-attn tile (192/128): expected GQA ratio multiple of 8");
    }

    if constexpr (DKQ <= 512 && DKQ != 320 && DKQ != 192) {
        if (use_gqa_opt && gqa_ratio % 8 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
            return;
@@ -1292,6 +1340,7 @@ extern DECL_FATTN_TILE_CASE( 80, 80);
extern DECL_FATTN_TILE_CASE( 96, 96);
extern DECL_FATTN_TILE_CASE(112, 112);
extern DECL_FATTN_TILE_CASE(128, 128);
extern DECL_FATTN_TILE_CASE(192, 128);
extern DECL_FATTN_TILE_CASE(256, 256);
extern DECL_FATTN_TILE_CASE(320, 256);
extern DECL_FATTN_TILE_CASE(512, 512);

@@ -139,6 +139,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
            GGML_ASSERT(V->ne[0] == 128);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
            break;
        case 192: {
            // MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn)
            GGML_ASSERT(V->ne[0] == 128);
            float max_bias = 0.0f;
            memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
            const bool use_gqa_opt = mask && max_bias == 0.0f;
            GGML_ASSERT(use_gqa_opt);
            GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
            const int gqa_ratio = Q->ne[2] / K->ne[2];
            if (gqa_ratio % 16 == 0) {
                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 16>(ctx, dst);
            } else {
                GGML_ASSERT(gqa_ratio % 8 == 0);
                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 8>(ctx, dst);
            }
        } break;
        case 256:
            GGML_ASSERT(V->ne[0] == 256);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
@@ -368,6 +384,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
                return BEST_FATTN_KERNEL_NONE;
            }
            break;
        case 192:
            if (V->ne[0] != 128 || !gqa_opt_applies) {
                return BEST_FATTN_KERNEL_NONE;
            }
            if (gqa_ratio % 8 != 0) {
                return BEST_FATTN_KERNEL_NONE;
            }
            break;
        case 320:
            if (V->ne[0] != 256 || !gqa_opt_applies) {
                return BEST_FATTN_KERNEL_NONE;
@@ -425,7 +449,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
    }

    // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
    // 192 satisfies % 64 == 0 but has no vec instance (DKQ != DV); force it onto the MMA path.
    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && Q->ne[0] != 192 && K->ne[1] % FATTN_KQ_STRIDE == 0;

    // If Turing tensor cores are available, use them:
    if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
@@ -454,7 +479,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const

    if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
        int gqa_ratio_eff = 1;
        const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
        const int ncols2_max = (Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8;
        while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
            gqa_ratio_eff *= 2;
        }
@@ -468,7 +493,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
    }

    // Use the WMMA kernel if possible:
    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) {
    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 512 && Q->ne[0] != 576) {
        if (can_use_vector_kernel && Q->ne[1] <= 2) {
            return BEST_FATTN_KERNEL_VEC;
        }
@@ -501,7 +526,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
    }

    // Use MFMA flash attention for CDNA (MI100+):
    if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
    if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
        const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
        // MMA vs tile crossover benchmarked on MI300X @ d32768:
        // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)

@@ -6,17 +6,18 @@ template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(
        const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
        const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
        /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
        /*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/
        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {

    for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
    for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) {
        for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
            // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
            const int i10 = blockIdx.x;
            const int i11 = z / ne12; // TODO fastdiv
            const int i12 = z % ne12;
            const uint2 dm = fast_div_modulo((uint32_t)z, ne12_fdv);
            const int i11 = dm.x;
            const int i12 = dm.y;

            const int i01 = src1[i10*s10 + i11*s11 + i12*s12];

@@ -42,17 +43,18 @@ template<typename src0_t, typename dst_t>
static __global__ void k_get_rows_float(
        const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
        const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
        /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
        /*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/
        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {

    for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
    for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) {
        for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
            // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
            const int i10 = blockIdx.x;
            const int i11 = z / ne12; // TODO fastdiv
            const int i12 = z % ne12;
            const uint2 dm = fast_div_modulo((uint32_t)z, ne12_fdv);
            const int i11 = dm.x;
            const int i12 = dm.y;

            if (i00 >= ne00) {
                return;
@@ -115,10 +117,14 @@ static void get_rows_cuda_q(

    GGML_ASSERT(ne00 % 2 == 0);

    GGML_ASSERT(ne12 > 0);
    GGML_ASSERT(ne11 <= std::numeric_limits<uint32_t>::max() / ne12);
    const uint3 ne12_fdv = init_fastdiv_values(ne12);

    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
        src0_d, src1_d, dst_d,
        ne00, /*ne01, ne02, ne03,*/
        /*ne10,*/ ne11, ne12, /*ne13,*/
        /*ne10,*/ ne11, ne12_fdv, /*ne13,*/
        /* s0,*/ s1, s2, s3,
        /* nb00,*/ nb01, nb02, nb03,
        s10, s11, s12/*, s13*/);
@@ -146,10 +152,14 @@ static void get_rows_cuda_float(
    const size_t s12 = nb12 / sizeof(int32_t);
    // const size_t s13 = nb13 / sizeof(int32_t);

    GGML_ASSERT(ne12 > 0);
    GGML_ASSERT(ne11 <= std::numeric_limits<uint32_t>::max() / ne12);
    const uint3 ne12_fdv = init_fastdiv_values(ne12);

    k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
        src0_d, src1_d, dst_d,
        ne00, /*ne01, ne02, ne03,*/
        /*ne10,*/ ne11, ne12, /*ne13,*/
        /*ne10,*/ ne11, ne12_fdv, /*ne13,*/
        /* s0,*/ s1, s2, s3,
        /* nb00,*/ nb01, nb02, nb03,
        s10, s11, s12/*, s13*/);

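Aside (illustrative, not part of the patch): the idea behind init_fastdiv_values/fast_div_modulo is that the divisor is fixed per kernel launch, so every thread can replace an integer division by a precomputed multiply-and-shift. A host-side sketch of the textbook round-up magic-number scheme (uses the GCC/Clang __uint128_t extension; this is the general idea, not ggml's exact implementation):

#include <cstdint>
#include <cstdio>

struct fastdiv_vals { uint64_t magic; uint32_t shift; uint32_t d; };

static fastdiv_vals init_fastdiv(uint32_t d) {
    uint32_t shift = 0;
    while ((1ull << shift) < d) shift++;
    // ceil(2^(32+shift) / d) so that (z * magic) >> (32 + shift) == z / d for 32-bit z
    const uint64_t magic = (uint64_t) (((__uint128_t(1) << (32 + shift)) + d - 1) / d);
    return { magic, shift, d };
}

static uint32_t fastdiv(uint32_t z, const fastdiv_vals & f) {
    return (uint32_t) ((__uint128_t(z) * f.magic) >> (32 + f.shift));
}

int main() {
    const fastdiv_vals f = init_fastdiv(12);
    for (uint32_t z : { 0u, 11u, 12u, 1000u, 4000000000u }) {
        const uint32_t q = fastdiv(z, f);
        printf("%u / 12 = %u, %u %% 12 = %u\n", z, q, z, z - q * f.d);
    }
    return 0;
}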
@@ -39,6 +39,7 @@
#include "ggml-cuda/rope.cuh"
#include "ggml-cuda/roll.cuh"
#include "ggml-cuda/scale.cuh"
#include "ggml-cuda/snake.cuh"
#include "ggml-cuda/softcap.cuh"
#include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/ssm-conv.cuh"
@@ -3757,6 +3758,35 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
        return 2;
    }

    // Snake activation: y = x + sin(a*x)^2 * inv_b
    // Naive 5-op decomposition emitted by frontends: mul -> sin -> sqr -> mul -> add
    if (ggml_can_fuse_subgraph(cgraph, i,
            { GGML_OP_MUL, GGML_OP_SIN, GGML_OP_SQR, GGML_OP_MUL, GGML_OP_ADD },
            { i + 4 })) {
        const ggml_tensor * mul0 = cgraph->nodes[i];
        const ggml_tensor * sqr  = cgraph->nodes[i + 2];
        const ggml_tensor * mul1 = cgraph->nodes[i + 3];
        ggml_tensor * add = cgraph->nodes[i + 4];

        // x carries the full activation shape, a is the broadcast operand
        const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
        const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];

        // mul1 reads sqr and inv_b in either operand order
        const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];

        // closure check: the trailing add must read the same x as the leading mul
        const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];

        const bool type_ok  = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16);
        const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1];

        if (type_ok && shape_ok && x_in_add == x && add->type == x->type) {
            ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add);
            return 4;
        }
    }

    // multi-(add or mul)
    if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
        int n_fuse = 0;
@@ -4588,8 +4618,8 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
    /* .free                  = */ ggml_backend_cuda_free,
    /* .set_tensor_async      = */ ggml_backend_cuda_set_tensor_async,
    /* .get_tensor_async      = */ ggml_backend_cuda_get_tensor_async,
    /* .get_tensor_2d_async   = */ ggml_backend_cuda_set_tensor_2d_async,
    /* .set_tensor_2d_async   = */ ggml_backend_cuda_get_tensor_2d_async,
    /* .set_tensor_2d_async   = */ ggml_backend_cuda_set_tensor_2d_async,
    /* .get_tensor_2d_async   = */ ggml_backend_cuda_get_tensor_2d_async,
    /* .cpy_tensor_async      = */ ggml_backend_cuda_cpy_tensor_async,
    /* .synchronize           = */ ggml_backend_cuda_synchronize,
    /* .graph_plan_create     = */ NULL,
@@ -5431,9 +5461,12 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
                CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
                dev_ctx->description = prop.name;

                char pci_bus_id[16] = {};
                snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
                char pci_bus_id[32] = {};
                CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i));
                dev_ctx->pci_bus_id = pci_bus_id;
                for (char & c : dev_ctx->pci_bus_id) {
                    c = std::tolower(c);
                }
                dev_ctx->op_offload_min_batch_size = min_batch_size;

                ggml_backend_dev_t dev = new ggml_backend_device {

@@ -54,15 +54,31 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const int64_t dps2 = ne2 / ne02;
    const int64_t dps3 = ne3 / ne03;

    // TODO batched matrix multiplication
    for (int64_t i3 = 0; i3 < ne3; ++i3) {
        for (int64_t i2 = 0; i2 < ne2; ++i2) {
    if (dps2 == 1 && ne2 > 1) {
        // src0 has uniform stride s02 along dim 2; batch the inner loop with a strided GEMM
        GGML_ASSERT(ne2 <= std::numeric_limits<int>::max());
        const int batch_count = (int) ne2;
        for (int64_t i3 = 0; i3 < ne3; ++i3) {
            CUBLAS_CHECK(
                cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
                cublasSgemmStridedBatched(handle, CUBLAS_OP_N, src1_cublas_op,
                        ne0, ne1, ne01,
                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
                                src1_d + i3 *s13 + i2 *s12, ldb,
                        &beta,  dst_d  + i3 *s3  + i2 *s2,  ldc));
                        &alpha, src0_d + (i3/dps3)*s03, lda, s02,
                                src1_d + i3 *s13, ldb, s12,
                        &beta,  dst_d  + i3 *s3, ldc, s2,
                        batch_count));
        }
    } else {
        // Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2
        // with non-uniform stride; would need cublasSgemmBatched with pointer arrays).
        for (int64_t i3 = 0; i3 < ne3; ++i3) {
            for (int64_t i2 = 0; i2 < ne2; ++i2) {
                CUBLAS_CHECK(
                    cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
                            ne0, ne1, ne01,
                            &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
                                    src1_d + i3 *s13 + i2 *s12, ldb,
                            &beta,  dst_d  + i3 *s3  + i2 *s2,  ldc));
            }
        }
    }
}

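Aside (illustrative, not part of the patch): what the strided-batched call replaces, written as a plain loop. Batch entry b reads each operand at a fixed base plus b times a uniform stride, which is exactly the layout that holds when dps2 == 1. The gemm() helper here is a hypothetical stand-in for one cublasSgemm call and is only declared, not defined:

#include <cstdint>

// Hypothetical stand-in for a single column-major GEMM (one cublasSgemm call).
void gemm(const float * A, const float * B, float * C, int m, int n, int k);

static void out_prod_batched(const float * src0_base, const float * src1_base, float * dst_base,
                             int batch_count, int64_t s02, int64_t s12, int64_t s2,
                             int ne0, int ne1, int ne01) {
    for (int b = 0; b < batch_count; b++) {
        gemm(src0_base + b * s02,  // A advances by a uniform stride per batch
             src1_base + b * s12,  // B likewise
             dst_base  + b * s2,   // C likewise
             ne0, ne1, ne01);
    }
    // cublasSgemmStridedBatched performs all batch_count GEMMs in one launch
    // by taking exactly these three (base pointer, stride) pairs.
}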
72  ggml/src/ggml-cuda/snake.cu  Normal file
@@ -0,0 +1,72 @@
#include "snake.cuh"
#include "convert.cuh"

// Fused Snake activation: y = x + sin^2(a * x) * inv_b
// x: [T, C] (T contiguous), a: [1, C], inv_b: [1, C]
// Supports F32, F16, BF16 data with F32 compute.

template <typename T>
static __global__ void snake_kernel(
        const T     * __restrict__ x,
        const float * __restrict__ a,
        const float * __restrict__ inv_b,
        T           * __restrict__ dst,
        const int     total,
        const uint3   T_len_fastdiv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total) return;

    const int c = (int) fastdiv((uint32_t) idx, T_len_fastdiv);

    const float xi = ggml_cuda_cast<float>(x[idx]);
    const float s  = sinf(a[c] * xi);
    dst[idx] = ggml_cuda_cast<T>(xi + s * s * inv_b[c]);
}

// Internal launcher with explicit x/a/inv_b/dst tensors.
// Shared by the public op (reads dst->src) and the fusion path (explicit args).
static void launch_snake(ggml_backend_cuda_context & ctx,
                         const ggml_tensor * x,
                         const ggml_tensor * a,
                         const ggml_tensor * inv_b,
                         ggml_tensor * dst) {
    const float * a_d     = (const float *) a->data;
    const float * inv_b_d = (const float *) inv_b->data;

    const int T = (int) x->ne[0];
    const int C = (int) x->ne[1];
    const int total = T * C;
    const uint3 T_len_fastdiv = init_fastdiv_values((uint64_t) T);

    const int block_size = 256;
    const int grid_size  = (total + block_size - 1) / block_size;

    cudaStream_t stream = ctx.stream();

    switch (x->type) {
        case GGML_TYPE_F32: {
            snake_kernel<<<grid_size, block_size, 0, stream>>>(
                (const float *) x->data, a_d, inv_b_d, (float *) dst->data, total, T_len_fastdiv);
        } break;
        case GGML_TYPE_F16: {
            snake_kernel<<<grid_size, block_size, 0, stream>>>(
                (const half *) x->data, a_d, inv_b_d, (half *) dst->data, total, T_len_fastdiv);
        } break;
        case GGML_TYPE_BF16: {
            snake_kernel<<<grid_size, block_size, 0, stream>>>(
                (const nv_bfloat16 *) x->data, a_d, inv_b_d, (nv_bfloat16 *) dst->data, total, T_len_fastdiv);
        } break;
        default:
            GGML_ABORT("snake: unsupported type");
    }
}

// Fusion entry: caller supplies x/a/inv_b explicitly from the matched
// mul -> sin -> sqr -> mul -> add pattern. The dst is the trailing add output.
void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx,
                              const ggml_tensor * x,
                              const ggml_tensor * a,
                              const ggml_tensor * inv_b,
                              ggml_tensor * dst) {
    launch_snake(ctx, x, a, inv_b, dst);
}
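As a cross-check of the kernel's math, a scalar C++ sketch of the same computation (illustrative only; snake_ref is not part of the codebase). The integer division idx / T that recovers the channel index is what the kernel's fastdiv performs:

    #include <cmath>

    // y = x + sin^2(a*x) * inv_b with per-channel a[c] and inv_b[c];
    // x and dst are [T, C] with T contiguous, so idx = c*T + t and c = idx / T.
    static void snake_ref(const float * x, const float * a, const float * inv_b,
                          float * dst, int T, int C) {
        for (int c = 0; c < C; ++c) {
            for (int t = 0; t < T; ++t) {
                const int idx = c * T + t;
                const float s = std::sin(a[c] * x[idx]);
                dst[idx] = x[idx] + s * s * inv_b[c];
            }
        }
    }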
8  ggml/src/ggml-cuda/snake.cuh  Normal file
@@ -0,0 +1,8 @@
#include "common.cuh"

// Fusion entry point. Caller supplies x/a/inv_b explicitly.
void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx,
                              const ggml_tensor * x,
                              const ggml_tensor * a,
                              const ggml_tensor * inv_b,
                              ggml_tensor * dst);
@@ -2,4 +2,5 @@
 
 #include "../fattn-mma-f16.cuh"
 
+DECL_FATTN_MMA_F16_CASE(192, 128, 1, 16);
 DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);

@@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8);
 DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
 DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
 DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
+DECL_FATTN_MMA_F16_CASE(192, 128, 1, 8);
 DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
 DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);

@@ -2,4 +2,5 @@
 
 #include "../fattn-mma-f16.cuh"
 
+DECL_FATTN_MMA_F16_CASE(192, 128, 2, 16);
 DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);

@@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8);
 DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
 DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
 DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
+DECL_FATTN_MMA_F16_CASE(192, 128, 2, 8);
 DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
 DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);

@@ -2,4 +2,5 @@
 
 #include "../fattn-mma-f16.cuh"
 
+DECL_FATTN_MMA_F16_CASE(192, 128, 4, 16);
 DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);

@@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8);
 DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
 DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
 DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
+DECL_FATTN_MMA_F16_CASE(192, 128, 4, 8);
 DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
 DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);

@@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8);
 DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
 DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
 DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
+DECL_FATTN_MMA_F16_CASE(192, 128, 8, 8);
 DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
 DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(192, 128);
@@ -3,7 +3,10 @@
 from glob import glob
 import os
 
-HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576]
+HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 192, 256, 320, 512, 576]
+
+# DKQ -> DV override for asymmetric head dims.
+HEAD_SIZES_V_OVERRIDE = {576: 512, 320: 256, 192: 128}
 
 TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]

@@ -62,7 +65,7 @@ for filename in glob("*.cu"):
     os.remove(filename)
 
 for head_size_kq in HEAD_SIZES_KQ:
-    head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
+    head_size_v = HEAD_SIZES_V_OVERRIDE.get(head_size_kq, head_size_kq)
     with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
         f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))

@@ -85,15 +88,17 @@ for ncols in [8, 16, 32, 64]:
             if head_size_kq == 72:
                 continue
             # Skip compilation of unused ncols2 values for niche head sizes:
+            if head_size_kq == 192 and ncols2 not in (8, 16): # MiMo-V2.5
+                continue
             if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4
                 continue
             if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4
                 continue
             if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash
                 continue
-            if head_size_kq not in (320, 576) and ncols2 in (16, 32):
+            if head_size_kq not in (192, 320, 576) and ncols2 in (16, 32):
                 continue
-            head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
+            head_size_v = HEAD_SIZES_V_OVERRIDE.get(head_size_kq, head_size_kq)
             f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
 
 for type in TYPES_MMQ:
2  ggml/src/ggml-cuda/vendors/hip.h  vendored
@@ -48,6 +48,7 @@
 #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
 #define cublasSetStream hipblasSetStream
 #define cublasSgemm hipblasSgemm
+#define cublasSgemmStridedBatched hipblasSgemmStridedBatched
 #define cublasStatus_t hipblasStatus_t
 #define cublasOperation_t hipblasOperation_t
 #define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch
@@ -55,6 +56,7 @@
 #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
 #define cudaDeviceGetAttribute hipDeviceGetAttribute
+#define cudaDeviceGetPCIBusId hipDeviceGetPCIBusId
 #define cudaDeviceProp hipDeviceProp_t
 #define cudaDeviceSynchronize hipDeviceSynchronize
 #define cudaError_t hipError_t

2  ggml/src/ggml-cuda/vendors/musa.h  vendored
@@ -32,6 +32,7 @@
 #define cublasSetMathMode mublasSetMathMode
 #define cublasSetStream mublasSetStream
 #define cublasSgemm mublasSgemm
+#define cublasSgemmStridedBatched mublasSgemmStridedBatched
 #define cublasStatus_t mublasStatus_t
 #define cublasOperation_t mublasOperation_t
 #define cublasGetStatusString mublasGetStatusString
@@ -39,6 +40,7 @@
 #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
 #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
 #define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
+#define cudaDeviceGetPCIBusId musaDeviceGetPCIBusId
 #define cudaDeviceProp musaDeviceProp
 #define cudaDeviceSynchronize musaDeviceSynchronize
 #define cudaError_t musaError_t
@@ -22,7 +22,8 @@ message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for
 include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
 include(ExternalProject)
 
 option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
+option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
 set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
 set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
@@ -48,14 +48,16 @@ using intvec = std::vector<int>;
 using uintvec = std::vector<unsigned int>;
 using u32vec = std::vector<uint32_t>;
 
-static size_t opt_ndev = 1;
-static size_t opt_nhvx = 0;    // use all
-static int opt_arch = 0;       // autodetect
-static int opt_etm = 0;
-static int opt_verbose = 0;
-static int opt_profile = 0;    // profiling mode (0-disabled, 1-basic, 2-pmu)
-static int opt_hostbuf = 1;    // hostbuf ON by default
-static int opt_use_hmx = 1;    // when set, enable HMX; when 0, use HVX only
+static int opt_arch = 0;       // autodetect
+static size_t opt_ndev = 1;
+static size_t opt_nhvx = 0;    // use all
+static int opt_use_hmx = 1;    // when set, enable HMX; when 0, use HVX only
+static size_t opt_vmem = HTP_OP_MAX_VMEM_DEFAULT;   // max available va space for buffer mappings
+static size_t opt_mbuf = 1ul * 1024 * 1024 * 1024;  // max buffer size
+static int opt_etm = 0;
+static int opt_verbose = 0;
+static int opt_profile = 0;    // profiling mode (0-disabled, 1-basic, 2-pmu)
+static int opt_hostbuf = 1;    // hostbuf ON by default
 
 // Default PMU events, if profiling with PMU (mode=2) is enabled
 // See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html

@@ -66,6 +68,7 @@ static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C };
 static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE;
 static int opt_opbatch = 1024; // max number of ops in a batch
 static int opt_opqueue = 16;   // max number of pending batches
 
+static std::regex* opt_opfilter = NULL; // regex of ops to not claim
 
 #define HEX_VERBOSE(...) \
@@ -110,7 +113,7 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct
     if (!opt_verbose) return;
 
     op_desc desc(op);
-    GGML_LOG_DEBUG("ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\n", sess_name.c_str(),
+    GGML_LOG_DEBUG("ggml-hex: %s supports-op %s: %s : %s : %s : %s : %s : %s\n", sess_name.c_str(),
                    ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no");
 }
@@ -118,8 +121,6 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t
                                       uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) {
     if (!opt_profile) return;
 
-    op_desc desc(op);
-
     char pmu_str[256] = "";
     if (opt_profile > 1) {
         static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters");
@@ -127,6 +128,7 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t
                  pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]);
     }
 
+    op_desc desc(op);
     GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(),
                    ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, op_usec, op_cycles, pmu_str);
 }
@@ -191,33 +193,30 @@ struct ggml_hexagon_shared_buffer {
     bool mapped;
     bool pinned;
 
-    void mmap(bool pinned = false) {
-        int err = fastrpc_mmap(sess->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD_DELAYED);
+    void mmap() {
+        fastrpc_map_flags flags = this->pinned ? FASTRPC_MAP_FD : FASTRPC_MAP_FD_DELAYED;
+
+        int err = fastrpc_mmap(sess->domain_id, this->fd, (void *) this->base, 0, this->size, flags);
         if (err != 0) {
             GGML_LOG_ERROR("ggml-hex: %s buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(),
                            sess->domain_id, this->size, this->fd, (unsigned) err);
             throw std::runtime_error("ggml-hex: fastrpc_mmap failed (see log for details)");
         }
 
-        if (pinned) {
-            err = htp_iface_mmap(sess->handle, this->fd, this->size, pinned);
-            if (err != 0) {
-                GGML_LOG_ERROR("ggml-hex: %s buffer pinning failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(),
-                               sess->domain_id, this->size, this->fd, (unsigned) err);
-                throw std::runtime_error("ggml-hex: htp_iface_mmap failed (see log for details)");
-            }
-        }
-
-        this->mapped = true;
-        this->pinned = pinned;
         HEX_VERBOSE("ggml-hex: %s mapped buffer: base %p size %zu fd %d pinned %u\n",
                     sess->c_name(), (void *) this->base, this->size, this->fd, pinned);
+
+        this->mapped = true;
     }
 
     void unmap() {
         if (!this->mapped) return;
 
-        htp_iface_munmap(sess->handle, this->fd);
+        if (!this->pinned) {
+            // HTP might still hold a reference, tell it to drop it
+            htp_iface_munmap(sess->handle, this->fd);
+        }
 
         fastrpc_munmap(sess->domain_id, this->fd, (void *) this->base, this->size);
 
         HEX_VERBOSE("ggml-hex: %s unmapped buffer: base %p size %zu fd %d\n", sess->c_name(),
@@ -227,7 +226,7 @@ struct ggml_hexagon_shared_buffer {
         this->fd = -1;
     }
 
-    void alloc(size_t size, bool pinned = false) {
+    void alloc(size_t size) {
         if (this->base) return;
 
         this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, size);
@@ -245,8 +244,7 @@ struct ggml_hexagon_shared_buffer {
 
         HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d pinned %d\n", sess->c_name(),
                     (void *) this->base, this->size, this->fd, (int) pinned);
-
-        mmap(pinned);
+        mmap();
     }
 
     void free() {
@@ -262,15 +260,14 @@ struct ggml_hexagon_shared_buffer {
     }
 
     ggml_hexagon_shared_buffer(ggml_hexagon_session * sess, size_t size, bool pinned = false) {
-        size += 4 * 1024; // extra page for padding
-
         this->sess = sess;
         this->size = 0;
         this->base = nullptr;
         this->fd = -1;
         this->mapped = false;
+        this->pinned = pinned;
 
-        alloc(size, pinned);
+        alloc(size);
     }
 
     ~ggml_hexagon_shared_buffer() {
@@ -1475,6 +1472,7 @@ static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer(
         ggml_backend_buffer_type_t buffer_type, size_t size) {
     auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;
     try {
+        size += 4 * 1024; // guard page
         ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size);
         return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size);
     } catch (const std::exception & exc) {
@@ -1487,6 +1485,7 @@ static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffe
         ggml_backend_buffer_type_t buffer_type, size_t size) {
     auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;
     try {
+        size += 4 * 1024; // guard page
         ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size);
         return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size);
     } catch (const std::exception & exc) {
@@ -1505,7 +1504,7 @@ static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffe
 }
 
 static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {
-    return 1UL * 1024 * 1024 * 1024; // 1GB per buffer
+    return opt_mbuf; // typically 1GB per buffer
     GGML_UNUSED(buffer_type);
 }
@@ -1573,14 +1572,14 @@ struct ggml_hexagon_opbatch {
         d_map.clear();
     }
 
-    ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size) {
+    ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size, size_t max_vmem) {
         this->sess = sess;
 
         n_bufs_max = HTP_OP_MAX_BUFS;
        n_ops_max  = batch_size;
         n_tens_max = n_ops_max + n_ops_max * HTP_OP_MAX_INPUTS;
 
-        b_vmem_max = HTP_OP_MAX_VMEM;
+        b_vmem_max = max_vmem;
 
         ops.resize(n_ops_max);
 
@@ -1592,6 +1591,9 @@ struct ggml_hexagon_opbatch {
         t_map.reserve(n_tens_max);
         d_map.reserve(n_tens_max);
 
+        GGML_LOG_INFO("ggml-hex: %s op batching: n-bufs %u n-tensors %u n-ops %u vmem %zu\n",
+                      sess->c_name(), n_bufs_max, n_tens_max, n_ops_max, b_vmem_max);
+
         reset();
     }
@@ -1925,6 +1927,8 @@ void ggml_hexagon_session::flush_batch() {
     // Bump pending flag (cleared in the session::flush once we get the response)
     this->op_pending++; // atomic inc
 
+    HEX_VERBOSE("ggml-hex: %s queue-opbatch: %p size %u\n", this->c_name(), dbuf.ptr, dbuf.size);
+
     int err = dspqueue_write(this->queue, 0, 1, &dbuf, sizeof(req), (const uint8_t*) &req, DSPQUEUE_TIMEOUT);
     if (err != 0) {
         GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->c_name(), (unsigned) err);
@@ -1944,6 +1948,35 @@ void ggml_hexagon_session::flush(bool all) {
     flush_pending(all);
 }
 
+static size_t ggml_hexagon_measure_max_vmem(ggml_hexagon_session *sess) {
+    // Allocate a bunch of pinned buffers till failure.
+    // This is kind of expensive but handy for figuring out exactly how much we can mmap on a specific device.
+    // Typically we're going to allocate all/most of these buffers anyway for the model weights.
+
+    std::vector<ggml_hexagon_shared_buffer *> sbufs;
+
+    const size_t MiB = 1024 * 1024;
+    const size_t GiB = MiB * 1024;
+
+    size_t vmem = 0;
+    size_t step = 256u * MiB;
+
+    try {
+        sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB;
+        sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB;
+        sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB;
+
+        while (1) {
+            sbufs.push_back(new ggml_hexagon_shared_buffer(sess, step, true));
+            vmem += step;
+        }
+    } catch (...) { }
+
+    for (auto b : sbufs) { delete b; }
+
+    return vmem - step; // backoff to account for overhead from internal mappings
+}
+
 void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
     this->valid_session = false;
     this->valid_handle = false;
@@ -1957,7 +1990,7 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
 
     this->op_pending = 0;
 
-    GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str());
+    GGML_LOG_DEBUG("ggml-hex: %s allocating new session\n", this->name.c_str());
 
     domain * my_domain = get_domain(this->domain_id);
     if (my_domain == NULL) {
@@ -2033,9 +2066,6 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
 
     this->valid_handle = true;
 
-    GGML_LOG_INFO("ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\n", this->name.c_str(),
-                  this->session_id, this->domain_id, session_uri, (unsigned long) this->handle);
-
     // Enable FastRPC QoS mode
     {
         struct remote_rpc_control_latency l;
@@ -2047,6 +2077,9 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
         }
     }
 
+    GGML_LOG_INFO("ggml-hex: %s new session : session-id %d domain-id %d uri %s handle 0x%lx\n", this->c_name(),
+                  this->session_id, this->domain_id, session_uri, (unsigned long) this->handle);
+
     const size_t req_q_size = (sizeof(htp_opbatch_req) * opt_opqueue * 2) + 1024;
     const size_t rsp_q_size = (sizeof(htp_opbatch_rsp) * opt_opqueue * 2) + 1024;
 
@@ -2091,13 +2124,19 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
     }
 
     // Allocate buffers and state for op batching
-    this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch);
     this->op_queue = new ggml_hexagon_opqueue(this, opt_opbatch, opt_opqueue);
 
-    // Start processing op batch requests
-    err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx);
+    if (!opt_vmem) {
+        opt_vmem = ggml_hexagon_measure_max_vmem(this);
+        GGML_LOG_INFO("ggml-hex: %s measured max vmem %zu\n", this->c_name(), opt_vmem);
+    }
+
+    this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch, opt_vmem);
+
+    // Start dspqueue/opbatch processing
+    err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx, opt_vmem);
     if (err != 0) {
-        GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err);
+        GGML_LOG_ERROR("ggml-hex: %s failed to start session: 0x%08x\n", this->c_name(), (unsigned) err);
         throw std::runtime_error("ggml-hex: iface start failed (see log for details)");
     }
     this->valid_iface = true;
@@ -2108,17 +2147,17 @@ void ggml_hexagon_session::release() noexcept(true) {
 
     int err;
 
-    delete this->op_batch;
-    delete this->op_queue;
-
-    // Stop the DSP-side service and close the queue
     if (this->valid_iface) {
+        // Stop dspqueue/opbatch processing
         err = htp_iface_stop(this->handle);
         if (err != 0) {
             GGML_ABORT("ggml-hex: htp_iface_stop failed: 0x%08x\n", (unsigned) err);
         }
     }
 
+    delete this->op_batch;
+    delete this->op_queue;
+
     if (opt_etm) {
         err = htp_iface_etm(this->handle, 0);
         if (err != 0) {
@@ -2215,14 +2254,65 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
         return false;
     }
 
-    if (dst->ne[2] != 1 || dst->ne[3] != 1) {
-        // FA during prompt still needs work
+    if (dst->ne[3] != 1) {
         return false;
     }
 
     return true;
 }
 
+static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * q     = op->src[0];
+    const struct ggml_tensor * k     = op->src[1];
+    const struct ggml_tensor * v     = op->src[2];
+    const struct ggml_tensor * g     = op->src[3];
+    const struct ggml_tensor * beta  = op->src[4];
+    const struct ggml_tensor * state = op->src[5];
+    const struct ggml_tensor * dst   = op;
+
+    if (!q || !k || !v || !g || !beta || !state) {
+        return false;
+    }
+
+    if (q->type != GGML_TYPE_F32 || k->type != GGML_TYPE_F32 || v->type != GGML_TYPE_F32 ||
+        g->type != GGML_TYPE_F32 || beta->type != GGML_TYPE_F32 || state->type != GGML_TYPE_F32 ||
+        dst->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    if (!ggml_is_contiguous_rows(q) || !ggml_is_contiguous_rows(k) || !ggml_is_contiguous_rows(v) ||
+        !ggml_is_contiguous(g) || !ggml_is_contiguous(beta) || !ggml_is_contiguous(state) ||
+        !ggml_is_contiguous(dst)) {
+        return false;
+    }
+
+    const int64_t S_v      = v->ne[0];
+    const int64_t H        = v->ne[1];
+    const int64_t n_tokens = v->ne[2];
+    const int64_t n_seqs   = v->ne[3];
+
+    if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) {
+        return false;
+    }
+    if (q->ne[0] != S_v || k->ne[0] != S_v || q->ne[1] <= 0 || k->ne[1] <= 0 ||
+        q->ne[2] != n_tokens || k->ne[2] != n_tokens || q->ne[3] <= 0 || k->ne[3] <= 0 ||
+        (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) {
+        return false;
+    }
+    if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) {
+        return false;
+    }
+    if (ggml_nelements(state) != S_v * S_v * H * n_seqs) {
+        return false;
+    }
+    if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) {
+        return false;
+    }
+
+    GGML_UNUSED(sess);
+    return true;
+}
+
 static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
@@ -2382,8 +2472,8 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
         return false;
     }
 
-    // TODO: add support for non-contigiuos tensors
-    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
+    // dst must be contiguous; src0 may be non-contiguous
+    if (!ggml_is_contiguous(dst)) {
         return false;
     }
 
@@ -2739,32 +2829,34 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) {
 
 static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
     switch (t->op) {
-        case GGML_OP_FLASH_ATTN_EXT: return HTP_OP_FLASH_ATTN_EXT;
-        case GGML_OP_MUL_MAT: return HTP_OP_MUL_MAT;
-        case GGML_OP_MUL_MAT_ID: return HTP_OP_MUL_MAT_ID;
-        case GGML_OP_MUL: return HTP_OP_MUL;
-        case GGML_OP_ADD: return HTP_OP_ADD;
-        case GGML_OP_ADD_ID: return HTP_OP_ADD_ID;
-        case GGML_OP_SUB: return HTP_OP_SUB;
-        case GGML_OP_DIV: return HTP_OP_DIV;
-        case GGML_OP_CPY: return HTP_OP_CPY;
-        case GGML_OP_CONT: return HTP_OP_CPY;
-        case GGML_OP_GET_ROWS: return HTP_OP_GET_ROWS;
-        case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS;
-        case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS;
-        case GGML_OP_ARGSORT: return HTP_OP_ARGSORT;
-        case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM;
-        case GGML_OP_SCALE: return HTP_OP_SCALE;
-        case GGML_OP_SQR: return HTP_OP_SQR;
-        case GGML_OP_SQRT: return HTP_OP_SQRT;
-        case GGML_OP_SOFT_MAX: return HTP_OP_SOFTMAX;
-        case GGML_OP_SSM_CONV: return HTP_OP_SSM_CONV;
-        case GGML_OP_ROPE: return HTP_OP_ROPE;
-        case GGML_OP_REPEAT: return HTP_OP_REPEAT;
-        case GGML_OP_CUMSUM: return HTP_OP_CUMSUM;
-        case GGML_OP_FILL: return HTP_OP_FILL;
-        case GGML_OP_DIAG: return HTP_OP_DIAG;
-        case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI;
+        case GGML_OP_FLASH_ATTN_EXT:  return HTP_OP_FLASH_ATTN_EXT;
+        case GGML_OP_MUL_MAT:         return HTP_OP_MUL_MAT;
+        case GGML_OP_MUL_MAT_ID:      return HTP_OP_MUL_MAT_ID;
+        case GGML_OP_MUL:             return HTP_OP_MUL;
+        case GGML_OP_ADD:             return HTP_OP_ADD;
+        case GGML_OP_ADD_ID:          return HTP_OP_ADD_ID;
+        case GGML_OP_SUB:             return HTP_OP_SUB;
+        case GGML_OP_DIV:             return HTP_OP_DIV;
+        case GGML_OP_CPY:             return HTP_OP_CPY;
+        case GGML_OP_CONT:            return HTP_OP_CPY;
+        case GGML_OP_GET_ROWS:        return HTP_OP_GET_ROWS;
+        case GGML_OP_SET_ROWS:        return HTP_OP_SET_ROWS;
+        case GGML_OP_SUM_ROWS:        return HTP_OP_SUM_ROWS;
+        case GGML_OP_ARGSORT:         return HTP_OP_ARGSORT;
+        case GGML_OP_L2_NORM:         return HTP_OP_L2_NORM;
+        case GGML_OP_RMS_NORM:        return HTP_OP_RMS_NORM;
+        case GGML_OP_SCALE:           return HTP_OP_SCALE;
+        case GGML_OP_SQR:             return HTP_OP_SQR;
+        case GGML_OP_SQRT:            return HTP_OP_SQRT;
+        case GGML_OP_SOFT_MAX:        return HTP_OP_SOFTMAX;
+        case GGML_OP_SSM_CONV:        return HTP_OP_SSM_CONV;
+        case GGML_OP_GATED_DELTA_NET: return HTP_OP_GATED_DELTA_NET;
+        case GGML_OP_ROPE:            return HTP_OP_ROPE;
+        case GGML_OP_REPEAT:          return HTP_OP_REPEAT;
+        case GGML_OP_CUMSUM:          return HTP_OP_CUMSUM;
+        case GGML_OP_FILL:            return HTP_OP_FILL;
+        case GGML_OP_DIAG:            return HTP_OP_DIAG;
+        case GGML_OP_SOLVE_TRI:       return HTP_OP_SOLVE_TRI;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(t)) {
                 case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU;
@@ -2997,8 +3089,8 @@ static struct ggml_backend_i hexagon_backend_i = {
     /* .free                 = */ ggml_backend_hexagon_free,
     /* .set_tensor_async     = */ NULL,
     /* .get_tensor_async     = */ NULL,
-    /* .get_tensor_2d_async  = */ NULL,
     /* .set_tensor_2d_async  = */ NULL,
+    /* .get_tensor_2d_async  = */ NULL,
     /* .cpy_tensor_async     = */ NULL,
     /* .synchronize          = */ ggml_backend_hexagon_synchronize,
     /* .graph_plan_create    = */ NULL,
@@ -3215,6 +3307,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             supp = ggml_hexagon_supported_add_id(sess, op);
             break;
 
+        case GGML_OP_L2_NORM:
+            supp = ggml_hexagon_supported_unary(sess, op);
+            break;
+
         case GGML_OP_RMS_NORM:
         case GGML_OP_SCALE:
             supp = ggml_hexagon_supported_unary(sess, op);
@@ -3298,6 +3394,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             supp = ggml_hexagon_supported_ssm_conv(sess, op);
             break;
 
+        case GGML_OP_GATED_DELTA_NET:
+            supp = ggml_hexagon_supported_gated_delta_net(sess, op);
+            break;
+
         case GGML_OP_CUMSUM:
             supp = ggml_hexagon_supported_cumsum(sess, op);
             break;
@@ -3380,21 +3480,6 @@ struct ggml_hexagon_registry {
 ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
     GGML_LOG_INFO("ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev %zu\n", opt_ndev);
 
-    if (!opt_arch) {
-        int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch);
-        if (err != 0) {
-            GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err);
-            opt_arch = 73;
-        }
-    }
-
-#if defined(__ANDROID__)
-    if (opt_arch < 75) {
-        opt_ndev = 1;
-        GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n");
-    }
-#endif
-
     GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch);
 
     // Create devices / sessions

@@ -3480,32 +3565,67 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
     static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL,
                   "please update hexagon_type to match ggml_type");
 
-    const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");
-    const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF");
-    const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE");
-    const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH");
-    const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE");
-    const char * str_opfilter= getenv("GGML_HEXAGON_OPFILTER");
-    const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
-    const char * str_etm     = getenv("GGML_HEXAGON_ETM");
-    const char * str_nhvx    = getenv("GGML_HEXAGON_NHVX");
-    const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
-    const char * str_ndev    = getenv("GGML_HEXAGON_NDEV");
-    const char * str_arch    = getenv("GGML_HEXAGON_ARCH");
+    const char * str_verbose  = getenv("GGML_HEXAGON_VERBOSE");
+    const char * str_hostbuf  = getenv("GGML_HEXAGON_HOSTBUF");
+    const char * str_opstage  = getenv("GGML_HEXAGON_OPSTAGE");
+    const char * str_opbatch  = getenv("GGML_HEXAGON_OPBATCH");
+    const char * str_opqueue  = getenv("GGML_HEXAGON_OPQUEUE");
+    const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER");
+    const char * str_profile  = getenv("GGML_HEXAGON_PROFILE");
+    const char * str_etm      = getenv("GGML_HEXAGON_ETM");
+    const char * str_nhvx     = getenv("GGML_HEXAGON_NHVX");
+    const char * str_use_hmx  = getenv("GGML_HEXAGON_USE_HMX");
+    const char * str_ndev     = getenv("GGML_HEXAGON_NDEV");
+    const char * str_arch     = getenv("GGML_HEXAGON_ARCH");
+    const char * str_vmem     = getenv("GGML_HEXAGON_VMEM");
+    const char * str_mbuf     = getenv("GGML_HEXAGON_MBUF");
+
+    // Init Arch first since it affects other defaults
+    if (!str_arch) {
+        int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch);
+        if (err != 0) {
+            GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err);
+            opt_arch = 73;
+        }
+    } else {
+        if (str_arch[0] == 'v' || str_arch[0] == 'V') {
+            str_arch++;
+        }
+        opt_arch = strtoul(str_arch, NULL, 0);
+    }
+
+    size_t MiB = 1024 * 1024;
+
+    // Update vmem default
+    opt_vmem = opt_arch >= 75 ? HTP_OP_MAX_VMEM_DEFAULT : 3000 * MiB;
 
     auto RE_ICASE = std::regex_constants::icase;
 
-    opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL;
-    opt_verbose = str_verbose ? atoi(str_verbose) : 0;
-    opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
-    opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage;
-    opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch;
-    opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue;
-    opt_etm = str_etm ? atoi(str_etm) : 0;
-    opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
-    opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx;
-    opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
-    opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
+    opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL;
+    opt_verbose  = str_verbose  ? atoi(str_verbose) : 0;
+    opt_hostbuf  = str_hostbuf  ? atoi(str_hostbuf) : opt_hostbuf;
+    opt_opstage  = str_opstage  ? strtoul(str_opstage, NULL, 0) : opt_opstage;
+    opt_opbatch  = str_opbatch  ? strtoul(str_opbatch, NULL, 0) : opt_opbatch;
+    opt_opqueue  = str_opqueue  ? strtoul(str_opqueue, NULL, 0) : opt_opqueue;
+    opt_profile  = str_profile  ? atoi(str_profile) : 0;
+    opt_etm      = str_etm      ? atoi(str_etm) : 0;
+    opt_nhvx     = str_nhvx     ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
+    opt_use_hmx  = str_use_hmx  ? atoi(str_use_hmx) : opt_use_hmx;
+    opt_ndev     = str_ndev     ? strtoul(str_ndev, NULL, 0) : opt_ndev;
+    opt_hostbuf  = str_hostbuf  ? atoi(str_hostbuf) : opt_hostbuf;
+    opt_mbuf     = str_mbuf     ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf;
+    opt_vmem     = str_vmem     ? strtoul(str_vmem, NULL, 0) * MiB : opt_vmem;
+
+    if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
+        opt_ndev = GGML_HEXAGON_MAX_SESSIONS;
+    }
+
+#if defined(__ANDROID__)
+    if (opt_arch < 75) {
+        opt_ndev = 1;
+        GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n");
+    }
+#endif
 
     if (str_profile) {
         opt_pmu_evt = [&]() -> std::vector<uint32_t> {

@@ -3520,17 +3640,6 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
                    vec_to_str<uint32_t, 16>(opt_pmu_evt).c_str());
     }
 
-    if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
-        opt_ndev = GGML_HEXAGON_MAX_SESSIONS;
-    }
-
-    if (str_arch) {
-        if (str_arch[0] == 'v') {
-            str_arch++;
-        }
-        opt_arch = strtoul(str_arch, NULL, 0);
-    }
-
     reg->context = new ggml_hexagon_registry(reg);
 }
@@ -37,6 +37,7 @@ add_library(${HTP_LIB} SHARED
     fill-ops.c
     diag-ops.c
     solve-tri-ops.c
+    gated-delta-net-ops.c
 )
 
 target_compile_definitions(${HTP_LIB} PRIVATE
@@ -44,6 +45,11 @@ target_compile_definitions(${HTP_LIB} PRIVATE
     $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
     FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
 
+if (GGML_HEXAGON_FA_EXP2_HF)
+    message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
+    target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
+endif()
+
 # HMX acceleration: available on v73+ architectures
 set(HTP_HMX_VERSIONS v73 v75 v79 v81)
 list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
@@ -52,11 +58,13 @@ if (_hmx_idx GREATER_EQUAL 0)
     target_sources(${HTP_LIB} PRIVATE
         hmx-queue.c
         hmx-matmul-ops.c
+        hmx-flash-attn-ops.c
     )
 
     # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
     set_source_files_properties(
         hmx-matmul-ops.c
+        hmx-flash-attn-ops.c
         PROPERTIES COMPILE_OPTIONS "-mhmx"
     )
@@ -138,15 +138,15 @@ set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
 set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
 
 #Compiler Options
-set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
+set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
 
 set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
-set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
-set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
+set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
+set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O2")
 
 set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
-set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
-set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
+set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
+set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O2")
 
 set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}")
 set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}")
@@ -17,13 +17,14 @@
 #include "htp-ctx.h"
 #include "htp-ops.h"
+#include "hmx-ops.h"
 
 // Must be multiple of 32
 #define FLASH_ATTN_BLOCK_SIZE (32 * 2)
 
 // This is a bit of a hack because the compiler is struggling to properly inline
 // the default hvx_vec_f32_to_f16 with output into the local array.
-static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
+static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
 {
     *(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
 }
@@ -621,6 +622,17 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
         return HTP_STATUS_NO_SUPPORT;
     }
 
+#ifdef HTP_HAS_HMX
+    // HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV
+    if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) {
+        int ret = hmx_flash_attn_ext(octx);
+        if (ret == HTP_STATUS_OK) {
+            return ret;
+        }
+        // VTCM too small or other failure -> fall through to HVX path
+    }
+#endif
+
     struct htp_fa_context factx;
     factx.octx = octx;
955  ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c  Normal file
@@ -0,0 +1,955 @@
#include <math.h>
#include <stdint.h>
#include <string.h>

#include "hvx-utils.h"

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"

#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif

#define HTP_GDN_MAX_SV 128

struct htp_gdn_context {
    struct htp_ops_context * octx;
    uint32_t  rows_per_thread;
    size_t    state_bytes;
    bool      use_vtcm;
    uint8_t * vtcm_state_base;
    size_t    vtcm_state_per_thread;
};

static inline float gdn_mul_dot_f32(float * restrict dst, const float * restrict mul,
                                    const float * restrict dot, uint32_t n) {
    HVX_Vector acc = Q6_V_vzero();

    const uint32_t epv  = 128 / sizeof(float);
    const uint32_t nvec = n / epv;
    const uint32_t tail = n % epv;
    for (uint32_t i = 0; i < nvec; ++i) {
        HVX_Vector vd   = hvx_vmemu(dst + i * epv);
        HVX_Vector vm   = hvx_vmem(mul + i * epv);
        HVX_Vector vdot = hvx_vmem(dot + i * epv);
        HVX_Vector out  = hvx_vec_mul_f32_f32(vd, vm);
        hvx_vmemu(dst + i * epv) = out;
        acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
    }

    if (tail) {
        const uint32_t off = nvec * epv;
        HVX_Vector vd   = hvx_vmemu(dst + off);
        HVX_Vector vm   = hvx_vmem(mul + off);
        HVX_Vector vdot = hvx_vmem(dot + off);
        HVX_Vector out  = hvx_vec_mul_f32_f32(vd, vm);
        hvx_vec_store_u(dst + off, tail * sizeof(float), out);
        HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
        HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot);
        acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
    }

    return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc));
}
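// Scalar reference for gdn_mul_dot_f32 and the dot4/dot8 variants below
// (illustrative sketch only, not part of this diff): each helper fuses an
// in-place elementwise multiply (or scaled add) with a dot product against
// the updated destination, so each state row is read and written exactly once.
//
//   static float gdn_mul_dot_ref_f32(float * dst, const float * mul,
//                                    const float * dot, uint32_t n) {
//       float sum = 0.0f;
//       for (uint32_t i = 0; i < n; ++i) {
//           dst[i] *= mul[i];       // same in-place update as the HVX loop
//           sum += dst[i] * dot[i]; // reduction over the updated values
//       }
//       return sum;
//   }
//
// The *_dot4/*_dot8 variants apply the identical update to 4 or 8 destination
// rows per pass, sharing the loads of mul and dot across rows.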
static inline float gdn_mul_scalar_dot_f32(float * restrict dst, float mul,
|
||||
const float * restrict dot, uint32_t n) {
|
||||
HVX_Vector acc = Q6_V_vzero();
|
||||
const HVX_Vector vmul = hvx_vec_splat_f32(mul);
|
||||
|
||||
const uint32_t epv = 128 / sizeof(float);
|
||||
const uint32_t nvec = n / epv;
|
||||
const uint32_t tail = n % epv;
|
||||
for (uint32_t i = 0; i < nvec; ++i) {
|
||||
HVX_Vector vd = hvx_vmemu(dst + i * epv);
|
||||
HVX_Vector vdot = hvx_vmem(dot + i * epv);
|
||||
HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul);
|
||||
hvx_vmemu(dst + i * epv) = out;
|
||||
acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
|
||||
}
|
||||
|
||||
if (tail) {
|
||||
const uint32_t off = nvec * epv;
|
||||
HVX_Vector vd = hvx_vmemu(dst + off);
|
||||
HVX_Vector vdot = hvx_vmem(dot + off);
|
||||
HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul);
|
||||
hvx_vec_store_u(dst + off, tail * sizeof(float), out);
|
||||
HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
|
||||
HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot);
|
||||
acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
|
||||
}
|
||||
|
||||
return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc));
|
||||
}
|
||||
|
||||
static inline float gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src,
|
||||
float scale, const float * restrict dot, uint32_t n) {
|
||||
HVX_Vector acc = Q6_V_vzero();
|
||||
const HVX_Vector vscale = hvx_vec_splat_f32(scale);
|
||||
|
||||
const uint32_t epv = 128 / sizeof(float);
|
||||
const uint32_t nvec = n / epv;
|
||||
const uint32_t tail = n % epv;
|
||||
for (uint32_t i = 0; i < nvec; ++i) {
|
||||
HVX_Vector vd = hvx_vmemu(dst + i * epv);
|
||||
HVX_Vector vs = hvx_vmem(src + i * epv);
|
||||
HVX_Vector vdot = hvx_vmem(dot + i * epv);
|
||||
HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale));
|
||||
hvx_vmemu(dst + i * epv) = out;
|
||||
acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
|
||||
}
|
||||
|
||||
if (tail) {
|
||||
const uint32_t off = nvec * epv;
|
||||
HVX_Vector vd = hvx_vmemu(dst + off);
|
||||
HVX_Vector vs = hvx_vmem(src + off);
|
||||
HVX_Vector vdot = hvx_vmem(dot + off);
|
||||
HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale));
|
||||
hvx_vec_store_u(dst + off, tail * sizeof(float), out);
|
||||
HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
|
||||
HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot);
|
||||
acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
|
||||
}
|
||||
|
||||
return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc));
|
||||
}
|
||||
|
||||
static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1,
|
||||
float * restrict dst2, float * restrict dst3, const float * restrict mul,
|
||||
const float * restrict dot, uint32_t n, float * restrict sums) {
|
||||
HVX_Vector acc0 = Q6_V_vzero();
|
||||
HVX_Vector acc1 = Q6_V_vzero();
|
||||
HVX_Vector acc2 = Q6_V_vzero();
|
||||
HVX_Vector acc3 = Q6_V_vzero();
|
||||
|
||||
const uint32_t epv = 128 / sizeof(float);
|
||||
const uint32_t nvec = n / epv;
|
||||
const uint32_t tail = n % epv;
|
||||
for (uint32_t i = 0; i < nvec; ++i) {
|
||||
HVX_Vector vm = hvx_vmem(mul + i * epv);
|
||||
HVX_Vector vdot = hvx_vmem(dot + i * epv);
|
||||
|
||||
HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm);
|
||||
HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm);
|
||||
HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm);
|
||||
HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm);
|
||||
|
||||
hvx_vmemu(dst0 + i * epv) = out0;
|
||||
hvx_vmemu(dst1 + i * epv) = out1;
|
||||
hvx_vmemu(dst2 + i * epv) = out2;
|
||||
hvx_vmemu(dst3 + i * epv) = out3;
|
||||
|
||||
acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
|
||||
acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
|
||||
acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
|
||||
acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
|
||||
}
|
||||
|
||||
if (tail) {
|
||||
const uint32_t off = nvec * epv;
|
||||
HVX_Vector vm = hvx_vmem(mul + off);
|
||||
HVX_Vector vdot = hvx_vmem(dot + off);
|
||||
HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
|
||||
HVX_Vector zero = Q6_V_vzero();
|
||||
|
||||
HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm);
|
||||
HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm);
|
||||
HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm);
|
||||
HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm);
|
||||
|
||||
hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
|
||||
hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
|
||||
hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
|
||||
hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
|
||||
|
||||
acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
|
||||
acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
|
||||
acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
|
||||
acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
|
||||
}
|
||||
|
||||
HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } };
|
||||
hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc));
|
||||
}
|
||||
|
||||
static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restrict dst1,
|
||||
float * restrict dst2, float * restrict dst3, float mul,
|
||||
const float * restrict dot, uint32_t n, float * restrict sums) {
|
||||
HVX_Vector acc0 = Q6_V_vzero();
|
||||
HVX_Vector acc1 = Q6_V_vzero();
|
||||
HVX_Vector acc2 = Q6_V_vzero();
|
||||
HVX_Vector acc3 = Q6_V_vzero();
|
||||
const HVX_Vector vmul = hvx_vec_splat_f32(mul);
|
||||
|
||||
const uint32_t epv = 128 / sizeof(float);
|
||||
const uint32_t nvec = n / epv;
|
||||
const uint32_t tail = n % epv;
|
||||
for (uint32_t i = 0; i < nvec; ++i) {
|
||||
HVX_Vector vdot = hvx_vmem(dot + i * epv);
|
||||
|
||||
HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul);
|
||||
HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul);
|
||||
HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul);
|
||||
HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul);
|
||||
|
||||
hvx_vmemu(dst0 + i * epv) = out0;
|
||||
hvx_vmemu(dst1 + i * epv) = out1;
|
||||
hvx_vmemu(dst2 + i * epv) = out2;
|
||||
hvx_vmemu(dst3 + i * epv) = out3;
|
||||
|
||||
acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
|
||||
acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
|
||||
acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
|
||||
acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
|
||||
}
|
||||
|
||||
if (tail) {
|
||||
const uint32_t off = nvec * epv;
|
||||
HVX_Vector vdot = hvx_vmem(dot + off);
|
||||
HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
|
||||
HVX_Vector zero = Q6_V_vzero();
|
||||
|
||||
HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul);
|
||||
HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul);
|
||||
HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul);
|
||||
HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul);
|
||||
|
||||
hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
|
||||
hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
|
||||
hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
|
||||
hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
|
||||
|
||||
acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
|
||||
acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
|
||||
acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
|
||||
acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
|
||||
}
|
||||
|
||||
HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } };
|
||||
hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc));
|
||||
}
|
||||
|
||||
static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restrict dst1,
|
||||
float * restrict dst2, float * restrict dst3, const float * restrict src,
|
||||
const float * restrict scale, const float * restrict dot, uint32_t n,
|
||||
float * restrict sums) {
|
||||
HVX_Vector acc0 = Q6_V_vzero();
|
||||
HVX_Vector acc1 = Q6_V_vzero();
|
||||
HVX_Vector acc2 = Q6_V_vzero();
|
||||
HVX_Vector acc3 = Q6_V_vzero();
|
||||
const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]);
|
||||
const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]);
|
||||
const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]);
|
||||
const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]);
|
||||
|
||||
const uint32_t epv = 128 / sizeof(float);
|
||||
const uint32_t nvec = n / epv;
|
||||
const uint32_t tail = n % epv;
|
||||
for (uint32_t i = 0; i < nvec; ++i) {
|
||||
HVX_Vector vs = hvx_vmem(src + i * epv);
|
||||
HVX_Vector vdot = hvx_vmem(dot + i * epv);
|
||||
|
||||
HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0));
|
||||
HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1));
|
||||
HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2));
|
||||
HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3));
|
||||
|
||||
hvx_vmemu(dst0 + i * epv) = out0;
|
||||
hvx_vmemu(dst1 + i * epv) = out1;
|
||||
hvx_vmemu(dst2 + i * epv) = out2;
|
||||
hvx_vmemu(dst3 + i * epv) = out3;
|
||||
|
||||
acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
|
||||
acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
|
||||
acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
|
||||
acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
|
||||
}
|
||||
|
||||
if (tail) {
|
||||
const uint32_t off = nvec * epv;
|
||||
HVX_Vector vs = hvx_vmem(src + off);
|
||||
HVX_Vector vdot = hvx_vmem(dot + off);
|
||||
HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
|
||||
HVX_Vector zero = Q6_V_vzero();
|
||||
|
||||
HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0));
|
||||
HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1));
|
||||
HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2));
|
||||
HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3));
|
||||
|
||||
hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
|
||||
hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
|
||||
hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
|
||||
hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
|
||||
|
||||
acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
|
||||
acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
|
||||
acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
|
||||
acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
|
||||
}
|
||||
|
||||
HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } };
|
||||
hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc));
|
||||
}
|
// Multiply 8 state rows elementwise by mul[], store them back, and accumulate
// the per-row dot products with dot[] into sums[0..7].
static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1,
                                    float * restrict dst2, float * restrict dst3, float * restrict dst4,
                                    float * restrict dst5, float * restrict dst6, float * restrict dst7,
                                    const float * restrict mul, const float * restrict dot, uint32_t n,
                                    float * restrict sums) {
    HVX_Vector acc0 = Q6_V_vzero();
    HVX_Vector acc1 = Q6_V_vzero();
    HVX_Vector acc2 = Q6_V_vzero();
    HVX_Vector acc3 = Q6_V_vzero();
    HVX_Vector acc4 = Q6_V_vzero();
    HVX_Vector acc5 = Q6_V_vzero();
    HVX_Vector acc6 = Q6_V_vzero();
    HVX_Vector acc7 = Q6_V_vzero();

    const uint32_t epv = 128 / sizeof(float);
    const uint32_t nvec = n / epv;
    const uint32_t tail = n % epv;
    for (uint32_t i = 0; i < nvec; ++i) {
        HVX_Vector vm = hvx_vmem(mul + i * epv);
        HVX_Vector vdot = hvx_vmem(dot + i * epv);

        HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm);
        HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm);
        HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm);
        HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm);
        HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vm);
        HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vm);
        HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vm);
        HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vm);

        hvx_vmemu(dst0 + i * epv) = out0;
        hvx_vmemu(dst1 + i * epv) = out1;
        hvx_vmemu(dst2 + i * epv) = out2;
        hvx_vmemu(dst3 + i * epv) = out3;
        hvx_vmemu(dst4 + i * epv) = out4;
        hvx_vmemu(dst5 + i * epv) = out5;
        hvx_vmemu(dst6 + i * epv) = out6;
        hvx_vmemu(dst7 + i * epv) = out7;

        acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
        acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
        acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
        acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
        acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot));
        acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot));
        acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot));
        acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
    }

    if (tail) {
        const uint32_t off = nvec * epv;
        HVX_Vector vm = hvx_vmem(mul + off);
        HVX_Vector vdot = hvx_vmem(dot + off);
        HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
        HVX_Vector zero = Q6_V_vzero();

        HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm);
        HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm);
        HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm);
        HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm);
        HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vm);
        HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vm);
        HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vm);
        HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vm);

        hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
        hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
        hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
        hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
        hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4);
        hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5);
        hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6);
        hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7);

        acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
        acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
        acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
        acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
        acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero));
        acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero));
        acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero));
        acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero));
    }

    HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } };
    HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } };
    hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA));
    hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB));
}

// Same as gdn_mul_dot8_f32, but with a single scalar multiplier splatted
// across all lanes instead of a per-element mul[] array.
static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restrict dst1,
                                           float * restrict dst2, float * restrict dst3, float * restrict dst4,
                                           float * restrict dst5, float * restrict dst6, float * restrict dst7,
                                           float mul, const float * restrict dot, uint32_t n, float * restrict sums) {
    HVX_Vector acc0 = Q6_V_vzero();
    HVX_Vector acc1 = Q6_V_vzero();
    HVX_Vector acc2 = Q6_V_vzero();
    HVX_Vector acc3 = Q6_V_vzero();
    HVX_Vector acc4 = Q6_V_vzero();
    HVX_Vector acc5 = Q6_V_vzero();
    HVX_Vector acc6 = Q6_V_vzero();
    HVX_Vector acc7 = Q6_V_vzero();
    const HVX_Vector vmul = hvx_vec_splat_f32(mul);

    const uint32_t epv = 128 / sizeof(float);
    const uint32_t nvec = n / epv;
    const uint32_t tail = n % epv;
    for (uint32_t i = 0; i < nvec; ++i) {
        HVX_Vector vdot = hvx_vmem(dot + i * epv);

        HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul);
        HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul);
        HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul);
        HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul);
        HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vmul);
        HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vmul);
        HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vmul);
        HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vmul);

        hvx_vmemu(dst0 + i * epv) = out0;
        hvx_vmemu(dst1 + i * epv) = out1;
        hvx_vmemu(dst2 + i * epv) = out2;
        hvx_vmemu(dst3 + i * epv) = out3;
        hvx_vmemu(dst4 + i * epv) = out4;
        hvx_vmemu(dst5 + i * epv) = out5;
        hvx_vmemu(dst6 + i * epv) = out6;
        hvx_vmemu(dst7 + i * epv) = out7;

        acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
        acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
        acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
        acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
        acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot));
        acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot));
        acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot));
        acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
    }

    if (tail) {
        const uint32_t off = nvec * epv;
        HVX_Vector vdot = hvx_vmem(dot + off);
        HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
        HVX_Vector zero = Q6_V_vzero();

        HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul);
        HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul);
        HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul);
        HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul);
        HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vmul);
        HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vmul);
        HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vmul);
        HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vmul);

        hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
        hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
        hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
        hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
        hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4);
        hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5);
        hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6);
        hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7);

        acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
        acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
        acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
        acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
        acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero));
        acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero));
        acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero));
        acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero));
    }

    HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } };
    HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } };
    hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA));
    hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB));
}

// Add src[] scaled by per-row factors scale[0..7] into 8 state rows, store
// them back, and accumulate the per-row dot products with dot[] into sums[0..7].
static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restrict dst1,
                                           float * restrict dst2, float * restrict dst3, float * restrict dst4,
                                           float * restrict dst5, float * restrict dst6, float * restrict dst7,
                                           const float * restrict src, const float * restrict scale,
                                           const float * restrict dot, uint32_t n, float * restrict sums) {
    HVX_Vector acc0 = Q6_V_vzero();
    HVX_Vector acc1 = Q6_V_vzero();
    HVX_Vector acc2 = Q6_V_vzero();
    HVX_Vector acc3 = Q6_V_vzero();
    HVX_Vector acc4 = Q6_V_vzero();
    HVX_Vector acc5 = Q6_V_vzero();
    HVX_Vector acc6 = Q6_V_vzero();
    HVX_Vector acc7 = Q6_V_vzero();
    const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]);
    const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]);
    const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]);
    const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]);
    const HVX_Vector scale4 = hvx_vec_splat_f32(scale[4]);
    const HVX_Vector scale5 = hvx_vec_splat_f32(scale[5]);
    const HVX_Vector scale6 = hvx_vec_splat_f32(scale[6]);
    const HVX_Vector scale7 = hvx_vec_splat_f32(scale[7]);

    const uint32_t epv = 128 / sizeof(float);
    const uint32_t nvec = n / epv;
    const uint32_t tail = n % epv;
    for (uint32_t i = 0; i < nvec; ++i) {
        HVX_Vector vs = hvx_vmem(src + i * epv);
        HVX_Vector vdot = hvx_vmem(dot + i * epv);

        HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0));
        HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1));
        HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2));
        HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3));
        HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + i * epv), hvx_vec_mul_f32_f32(vs, scale4));
        HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + i * epv), hvx_vec_mul_f32_f32(vs, scale5));
        HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + i * epv), hvx_vec_mul_f32_f32(vs, scale6));
        HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + i * epv), hvx_vec_mul_f32_f32(vs, scale7));

        hvx_vmemu(dst0 + i * epv) = out0;
        hvx_vmemu(dst1 + i * epv) = out1;
        hvx_vmemu(dst2 + i * epv) = out2;
        hvx_vmemu(dst3 + i * epv) = out3;
        hvx_vmemu(dst4 + i * epv) = out4;
        hvx_vmemu(dst5 + i * epv) = out5;
        hvx_vmemu(dst6 + i * epv) = out6;
        hvx_vmemu(dst7 + i * epv) = out7;

        acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
        acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
        acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
        acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
        acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot));
        acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot));
        acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot));
        acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
    }

    if (tail) {
        const uint32_t off = nvec * epv;
        HVX_Vector vs = hvx_vmem(src + off);
        HVX_Vector vdot = hvx_vmem(dot + off);
        HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
        HVX_Vector zero = Q6_V_vzero();

        HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0));
        HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1));
        HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2));
        HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3));
        HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + off), hvx_vec_mul_f32_f32(vs, scale4));
        HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + off), hvx_vec_mul_f32_f32(vs, scale5));
        HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + off), hvx_vec_mul_f32_f32(vs, scale6));
        HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + off), hvx_vec_mul_f32_f32(vs, scale7));

        hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
        hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
        hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
        hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
        hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4);
        hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5);
        hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6);
        hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7);

        acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
        acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
        acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
        acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
        acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero));
        acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero));
        acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero));
        acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero));
    }

    HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } };
    HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } };
    hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA));
    hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB));
}

// Prompt-processing path (n_tokens > 1): each thread owns whole (head, sequence)
// rows and walks the tokens of that row sequentially, 4 state rows at a time.
static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, void * data) {
    struct htp_gdn_context * gctx = (struct htp_gdn_context *) data;
    struct htp_ops_context * octx = gctx->octx;

    const struct htp_tensor * q = octx->src[0];
    const struct htp_tensor * k = octx->src[1];
    const struct htp_tensor * v = octx->src[2];
    const struct htp_tensor * g = octx->src[3];
    const struct htp_tensor * beta = octx->src[4];
    const struct htp_tensor * state = octx->src[5];
    const struct htp_tensor * dst = octx->dst;

    const uint32_t S_v = v->ne[0];
    const uint32_t H = v->ne[1];
    const uint32_t n_tokens = v->ne[2];
    const uint32_t n_seqs = v->ne[3];

    const uint32_t total_rows = H * n_seqs;
    if (ith >= total_rows) {
        return;
    }

    const uint32_t rq3 = n_seqs / q->ne[3];
    const uint32_t rk3 = n_seqs / k->ne[3];
    const float scale = 1.0f / sqrtf((float) S_v);

    float * dst_base = (float *) (uintptr_t) dst->data;
    float * state_out_base = dst_base + (uint64_t) S_v * H * n_tokens * n_seqs;
    const float * state_in_base = (const float *) (uintptr_t) state->data;

    const bool kda = (g->ne[0] == S_v);
    float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
    float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
    float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
    float local_sums[4] __attribute__((aligned(128)));

    for (uint32_t ir = ith; ir < total_rows; ir += nth) {
        const uint32_t iv1 = ir % H;
        const uint32_t iv3 = ir / H;

        const uint32_t iq1 = iv1 % q->ne[1];
        const uint32_t ik1 = iv1 % k->ne[1];
        const uint32_t iq3 = iv3 / rq3;
        const uint32_t ik3 = iv3 / rk3;

        float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
        const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;

        memcpy(s_out, s_in, gctx->state_bytes);
        float * s_work = s_out;

        float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v;

        for (uint32_t t = 0; t < n_tokens; ++t) {
            const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data +
                (uint64_t) iq3 * q->nb[3] + (uint64_t) t * q->nb[2] + (uint64_t) iq1 * q->nb[1]);
            const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data +
                (uint64_t) ik3 * k->nb[3] + (uint64_t) t * k->nb[2] + (uint64_t) ik1 * k->nb[1]);
            const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data +
                (uint64_t) iv3 * v->nb[3] + (uint64_t) t * v->nb[2] + (uint64_t) iv1 * v->nb[1]);
            const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data +
                (uint64_t) iv3 * g->nb[3] + (uint64_t) t * g->nb[2] + (uint64_t) iv1 * g->nb[1]);
            const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data +
                (uint64_t) iv3 * beta->nb[3] + (uint64_t) t * beta->nb[2] + (uint64_t) iv1 * beta->nb[1]);

            memcpy(local_q, q_t, (size_t) S_v * sizeof(float));
            memcpy(local_k, k_t, (size_t) S_v * sizeof(float));

            if (kda) {
                hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false);

                uint32_t j = 0;
                for (; j + 4 <= S_v; j += 4) {
                    float * row0 = s_work + (uint64_t) (j + 0) * S_v;
                    float * row1 = s_work + (uint64_t) (j + 1) * S_v;
                    float * row2 = s_work + (uint64_t) (j + 2) * S_v;
                    float * row3 = s_work + (uint64_t) (j + 3) * S_v;
                    gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums);
                    float local_delta_b[4] __attribute__((aligned(128)));
                    for (uint32_t r = 0; r < 4; ++r) {
                        local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
                    }
                    gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
                    for (uint32_t r = 0; r < 4; ++r) {
                        attn_data[j + r] = local_sums[r] * scale;
                    }
                }
                for (; j < S_v; ++j) {
                    float * row = s_work + (uint64_t) j * S_v;
                    const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v);
                    const float dj = (v_t[j] - sum) * beta_val;
                    attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
                }
            } else {
                const float gate = expf(g_t[0]);
                uint32_t j = 0;
                for (; j + 4 <= S_v; j += 4) {
                    float * row0 = s_work + (uint64_t) (j + 0) * S_v;
                    float * row1 = s_work + (uint64_t) (j + 1) * S_v;
                    float * row2 = s_work + (uint64_t) (j + 2) * S_v;
                    float * row3 = s_work + (uint64_t) (j + 3) * S_v;
                    gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums);
                    float local_delta_b[4] __attribute__((aligned(128)));
                    for (uint32_t r = 0; r < 4; ++r) {
                        local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
                    }
                    gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
                    for (uint32_t r = 0; r < 4; ++r) {
                        attn_data[j + r] = local_sums[r] * scale;
                    }
                }
                for (; j < S_v; ++j) {
                    float * row = s_work + (uint64_t) j * S_v;
                    const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v);
                    const float dj = (v_t[j] - sum) * beta_val;
                    attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
                }
            }

            attn_data += (uint64_t) S_v * H;
        }
    }
}

// Token-generation path (n_tokens == 1): same update, 8 state rows at a time,
// optionally staging the state in VTCM via DMA to avoid working out of DDR.
static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, void * data) {
    struct htp_gdn_context * gctx = (struct htp_gdn_context *) data;
    struct htp_ops_context * octx = gctx->octx;

    const struct htp_tensor * q = octx->src[0];
    const struct htp_tensor * k = octx->src[1];
    const struct htp_tensor * v = octx->src[2];
    const struct htp_tensor * g = octx->src[3];
    const struct htp_tensor * beta = octx->src[4];
    const struct htp_tensor * state = octx->src[5];
    const struct htp_tensor * dst = octx->dst;

    const uint32_t S_v = v->ne[0];
    const uint32_t H = v->ne[1];
    const uint32_t n_seqs = v->ne[3];

    const uint32_t total_rows = H * n_seqs;
    if (ith >= total_rows) {
        return;
    }

    const uint32_t rq3 = n_seqs / q->ne[3];
    const uint32_t rk3 = n_seqs / k->ne[3];
    const float scale = 1.0f / sqrtf((float) S_v);

    float * dst_base = (float *) (uintptr_t) dst->data;
    float * state_out_base = dst_base + (uint64_t) S_v * H * n_seqs;
    const float * state_in_base = (const float *) (uintptr_t) state->data;

    const bool kda = (g->ne[0] == S_v);
    float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
    float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
    float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
    float local_sums[8] __attribute__((aligned(128)));

    dma_queue * dma = octx->ctx->dma[ith];

    uint8_t * spad = NULL;
    if (gctx->use_vtcm) {
        spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith;
    }

    for (uint32_t ir = ith; ir < total_rows; ir += nth) {
        const uint32_t iv1 = ir % H;
        const uint32_t iv3 = ir / H;

        const uint32_t iq1 = iv1 % q->ne[1];
        const uint32_t ik1 = iv1 % k->ne[1];
        const uint32_t iq3 = iv3 / rq3;
        const uint32_t ik3 = iv3 / rk3;

        float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
        const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
        float * s_work;

        if (spad) {
            dma_queue_push(dma, dma_make_ptr(spad, s_in),
                S_v * sizeof(float), S_v * sizeof(float),
                S_v * sizeof(float), S_v);
            dma_queue_pop(dma);
            s_work = (float *) spad;
        } else {
            s_work = s_out;
            memcpy(s_work, s_in, gctx->state_bytes);
        }

        float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v;

        const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data +
            (uint64_t) iq3 * q->nb[3] + (uint64_t) iq1 * q->nb[1]);
        const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data +
            (uint64_t) ik3 * k->nb[3] + (uint64_t) ik1 * k->nb[1]);
        const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data +
            (uint64_t) iv3 * v->nb[3] + (uint64_t) iv1 * v->nb[1]);
        const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data +
            (uint64_t) iv3 * g->nb[3] + (uint64_t) iv1 * g->nb[1]);
        const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data +
            (uint64_t) iv3 * beta->nb[3] + (uint64_t) iv1 * beta->nb[1]);

        memcpy(local_q, q_t, (size_t) S_v * sizeof(float));
        memcpy(local_k, k_t, (size_t) S_v * sizeof(float));

        if (kda) {
            hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false);

            uint32_t j = 0;
            for (; j + 8 <= S_v; j += 8) {
                float * row0 = s_work + (uint64_t) (j + 0) * S_v;
                float * row1 = s_work + (uint64_t) (j + 1) * S_v;
                float * row2 = s_work + (uint64_t) (j + 2) * S_v;
                float * row3 = s_work + (uint64_t) (j + 3) * S_v;
                float * row4 = s_work + (uint64_t) (j + 4) * S_v;
                float * row5 = s_work + (uint64_t) (j + 5) * S_v;
                float * row6 = s_work + (uint64_t) (j + 6) * S_v;
                float * row7 = s_work + (uint64_t) (j + 7) * S_v;
                gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
                    local_gate, local_k, S_v, local_sums);
                float local_delta_b[8] __attribute__((aligned(128)));
                for (uint32_t r = 0; r < 8; ++r) {
                    local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
                }
                gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
                    local_k, local_delta_b, local_q, S_v, local_sums);
                for (uint32_t r = 0; r < 8; ++r) {
                    attn_data[j + r] = local_sums[r] * scale;
                }
            }
            for (; j + 4 <= S_v; j += 4) {
                float * row0 = s_work + (uint64_t) (j + 0) * S_v;
                float * row1 = s_work + (uint64_t) (j + 1) * S_v;
                float * row2 = s_work + (uint64_t) (j + 2) * S_v;
                float * row3 = s_work + (uint64_t) (j + 3) * S_v;
                gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums);
                float local_delta_b[4] __attribute__((aligned(128)));
                for (uint32_t r = 0; r < 4; ++r) {
                    local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
                }
                gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
                for (uint32_t r = 0; r < 4; ++r) {
                    attn_data[j + r] = local_sums[r] * scale;
                }
            }
            for (; j < S_v; ++j) {
                float * row = s_work + (uint64_t) j * S_v;
                const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v);
                const float dj = (v_t[j] - sum) * beta_val;
                attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
            }
        } else {
            const float gate = expf(g_t[0]);
            uint32_t j = 0;
            for (; j + 8 <= S_v; j += 8) {
                float * row0 = s_work + (uint64_t) (j + 0) * S_v;
                float * row1 = s_work + (uint64_t) (j + 1) * S_v;
                float * row2 = s_work + (uint64_t) (j + 2) * S_v;
                float * row3 = s_work + (uint64_t) (j + 3) * S_v;
                float * row4 = s_work + (uint64_t) (j + 4) * S_v;
                float * row5 = s_work + (uint64_t) (j + 5) * S_v;
                float * row6 = s_work + (uint64_t) (j + 6) * S_v;
                float * row7 = s_work + (uint64_t) (j + 7) * S_v;
                gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
                    gate, local_k, S_v, local_sums);
                float local_delta_b[8] __attribute__((aligned(128)));
                for (uint32_t r = 0; r < 8; ++r) {
                    local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
                }
                gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
                    local_k, local_delta_b, local_q, S_v, local_sums);
                for (uint32_t r = 0; r < 8; ++r) {
                    attn_data[j + r] = local_sums[r] * scale;
                }
            }
            for (; j + 4 <= S_v; j += 4) {
                float * row0 = s_work + (uint64_t) (j + 0) * S_v;
                float * row1 = s_work + (uint64_t) (j + 1) * S_v;
                float * row2 = s_work + (uint64_t) (j + 2) * S_v;
                float * row3 = s_work + (uint64_t) (j + 3) * S_v;
                gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums);
                float local_delta_b[4] __attribute__((aligned(128)));
                for (uint32_t r = 0; r < 4; ++r) {
                    local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
                }
                gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
                for (uint32_t r = 0; r < 4; ++r) {
                    attn_data[j + r] = local_sums[r] * scale;
                }
            }
            for (; j < S_v; ++j) {
                float * row = s_work + (uint64_t) j * S_v;
                const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v);
                const float dj = (v_t[j] - sum) * beta_val;
                attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
            }
        }

        if (spad) {
            dma_queue_push(dma, dma_make_ptr(s_out, spad),
                S_v * sizeof(float), S_v * sizeof(float),
                S_v * sizeof(float), S_v);
            dma_queue_pop(dma);
        }
    }
}

// Validate shapes/types, set up optional VTCM staging, and dispatch to the
// token-generation or prompt-processing worker.
int op_gated_delta_net(struct htp_ops_context * octx) {
    const struct htp_tensor * q = octx->src[0];
    const struct htp_tensor * k = octx->src[1];
    const struct htp_tensor * v = octx->src[2];
    const struct htp_tensor * g = octx->src[3];
    const struct htp_tensor * beta = octx->src[4];
    const struct htp_tensor * state = octx->src[5];
    const struct htp_tensor * dst = octx->dst;

    if (!q || !k || !v || !g || !beta || !state || !dst) {
        return HTP_STATUS_INVAL_PARAMS;
    }

    if (q->type != HTP_TYPE_F32 || k->type != HTP_TYPE_F32 || v->type != HTP_TYPE_F32 ||
        g->type != HTP_TYPE_F32 || beta->type != HTP_TYPE_F32 || state->type != HTP_TYPE_F32 ||
        dst->type != HTP_TYPE_F32) {
        return HTP_STATUS_NO_SUPPORT;
    }

    const uint32_t S_v = v->ne[0];
    const uint32_t H = v->ne[1];
    const uint32_t n_tokens = v->ne[2];
    const uint32_t n_seqs = v->ne[3];

    if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) {
        return HTP_STATUS_NO_SUPPORT;
    }
    if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) {
        return HTP_STATUS_NO_SUPPORT;
    }
    if (q->ne[0] != S_v || k->ne[0] != S_v || q->ne[1] == 0 || k->ne[1] == 0 ||
        q->ne[2] != n_tokens || k->ne[2] != n_tokens || q->ne[3] == 0 || k->ne[3] == 0 ||
        (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) {
        return HTP_STATUS_NO_SUPPORT;
    }
    if (state->ne[0] * state->ne[1] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) {
        return HTP_STATUS_NO_SUPPORT;
    }
    if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) {
        return HTP_STATUS_NO_SUPPORT;
    }

    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
        return HTP_STATUS_OK;
    }

    struct htp_gdn_context gctx;
    gctx.octx = octx;
    gctx.rows_per_thread = (H * n_seqs + octx->n_threads - 1) / octx->n_threads;
    gctx.state_bytes = (size_t) S_v * S_v * sizeof(float);

    size_t state_aligned = (size_t) S_v * S_v * sizeof(float);
    state_aligned = (state_aligned + 127) & ~(size_t)127;

    gctx.use_vtcm = false;
    gctx.vtcm_state_base = NULL;
    gctx.vtcm_state_per_thread = 0;

    if (n_tokens == 1 && octx->ctx->vtcm_base) {
        size_t vtcm_total = state_aligned * octx->n_threads;
        if (octx->ctx->vtcm_size >= vtcm_total) {
            gctx.use_vtcm = true;
            gctx.vtcm_state_base = octx->ctx->vtcm_base;
            gctx.vtcm_state_per_thread = state_aligned;
        }
    }

    if (n_tokens == 1) {
        worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_tg_thread, &gctx, octx->n_threads);
    } else {
        worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_pp_thread, &gctx, octx->n_threads);
    }

    return HTP_STATUS_OK;
}
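
// Reference sketch (not part of the commit): scalar form of the per-row state
// update the HVX kernels above implement, kept here only to document the math.
// As the state_out_base offset shows, dst holds the attention output followed
// by the updated state. Function and parameter names are illustrative.
static void gdn_row_update_ref(float * S, float * out, const float * q, const float * k,
                               const float * v, const float * g, float beta, float scale,
                               uint32_t S_v) {
    for (uint32_t j = 0; j < S_v; ++j) {
        float * row = S + (size_t) j * S_v;
        float sum = 0.0f;
        for (uint32_t i = 0; i < S_v; ++i) {
            row[i] *= g[i];        // decay (per-element gate; a scalar gate in the non-kda path)
            sum += row[i] * k[i];  // dot(S[j], k) after decay
        }
        const float delta = (v[j] - sum) * beta;
        float dq = 0.0f;
        for (uint32_t i = 0; i < S_v; ++i) {
            row[i] += delta * k[i]; // rank-1 delta-rule update
            dq += row[i] * q[i];    // dot(S[j], q) after the update
        }
        out[j] = dq * scale;
    }
}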
@@ -74,6 +74,12 @@ static inline size_t hex_smax(size_t a, size_t b) {
    return a > b ? a : b;
}

+static inline void hex_swap_ptr(void ** p1, void ** p2) {
+    void * t = *p1;
+    *p1 = *p2;
+    *p2 = t;
+}
+
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
    const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
    Q6_l2fetch_AP((void *) p, control);
ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c (new file, 1840 lines): file diff suppressed because it is too large.
@@ -61,6 +61,9 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx,
                                     int m, int k, int n,
                                     int weight_type);

+// HMX flash attention
+int hmx_flash_attn_ext(struct htp_ops_context * octx);
+
#ifdef __cplusplus
}
#endif
@@ -4,6 +4,9 @@
#ifndef HMX_UTILS_H
#define HMX_UTILS_H

+#include "hvx-base.h"
+
+#include <assert.h>
#include <hexagon_types.h>
#include <stddef.h>
@@ -12,21 +15,188 @@
#define HMX_FP16_TILE_N_ELMS 1024
#define HMX_FP16_TILE_SIZE 2048

-#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline))
-
// Initialise aligned 256-byte area with scale vector + zero padding.
-static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
-    HVX_Vector *pv = (HVX_Vector *)out_scales;
-    *pv++ = v_scale;
-    *pv = Q6_V_vzero();
+static inline void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
+    volatile HVX_Vector *pv = (HVX_Vector *) out_scales;
+    pv[0] = v_scale;
+    pv[1] = Q6_V_vzero();
}

-// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) ---
+// --- Shared scatter offsets and interleave helper ---

-static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
-    uint8_t *p = *vtcm_ptr;
-    *vtcm_ptr += size;
-    return p;
-}
-
+// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile.
+// word[i] = i*128 maps K-row-pair i to byte offset i*128.
+// Column offset (n*4) is added at runtime. Entries 0..15 cover one tile (region 2047);
+// entries 16..31 cover the next adjacent tile (region 4095) — pick region size at the
+// call site to scatter into one tile (masked) or two contiguous tiles (unmasked).
+static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = {
+    0 * 128,  1 * 128,  2 * 128,  3 * 128,  4 * 128,  5 * 128,  6 * 128,  7 * 128,  8 * 128,  9 * 128,  10 * 128,
+    11 * 128, 12 * 128, 13 * 128, 14 * 128, 15 * 128, 16 * 128, 17 * 128, 18 * 128, 19 * 128, 20 * 128, 21 * 128,
+    22 * 128, 23 * 128, 24 * 128, 25 * 128, 26 * 128, 27 * 128, 28 * 128, 29 * 128, 30 * 128, 31 * 128,
+};
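
// Worked example (added for clarity, not from the commit): scattering column
// n = 5 for K-row-pair i = 3 lands at byte offset 3*128 + 5*4 = 404 inside the
// tile. With the single-tile region (2047) only entries 0..15 stay in bounds,
// so that path masks the scatter; with region 4095 entries 16..31 reach the
// adjacent 2048-byte tile as well.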

// Scatter row-major FP16 data (in VTCM scratch) into transposed [K][N] tiles.
// vtcm_src: [n_cols][src_stride] row-major fp16 (only first k elements per row are used)
// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16
// Processes rows [start_row, end_row) for multi-thread slicing.
// Full range: start_row=0, end_row=n_cols.
static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
                                                const __fp16 * restrict vtcm_src,
                                                int n_cols,
                                                int k,
                                                int src_stride,
                                                int start_row,
                                                int end_row) {
    assert(k % HMX_FP16_TILE_N_COLS == 0);

    const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
    const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
    const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
    const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
    // Each hvx_vmemu load brings 64 fp16 = 128 bytes covering 2 adjacent K-tiles.
    // When n_k_tiles is even, scatter into 2 K-tiles per call (region 4095, no mask)
    // using the upper half of hmx_transpose_scatter_offsets. Tail one K-tile (when
    // n_k_tiles is odd) falls back to single-tile masked scatter.
    const bool pair_scatter = (n_k_tiles & 1) == 0;
    const size_t pair_region = (size_t) (2 * HMX_FP16_TILE_SIZE - 1);
    const size_t single_region = (size_t) (HMX_FP16_TILE_SIZE - 1);
    __builtin_assume(k > 0);
    __builtin_assume(end_row > start_row);

    if (pair_scatter) {
        // Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter.
        const int c_step = 2 * HMX_FP16_TILE_N_COLS;
        const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
        const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
        const int n_c_iters = k / c_step;

        for (int r = start_row; r < end_row; r += 2) {
            const int ct = r / HMX_FP16_TILE_N_ROWS;
            const int local_r = r % HMX_FP16_TILE_N_ROWS;
            const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
            const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
            const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);

            __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
            const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
            const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;

            if (p1) {
                for (int i = 0; i < n_c_iters; ++i) {
                    HVX_Vector v0 = hvx_vmemu(p0);
                    p0 += c_byte_step;
                    HVX_Vector v1 = hvx_vmemu(p1);
                    p1 += c_byte_step;
                    Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
                    Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1);
                    tile_base += dst_step;
                }
            } else {
                const HVX_Vector vzero = Q6_V_vzero();
                for (int i = 0; i < n_c_iters; ++i) {
                    HVX_Vector v0 = hvx_vmemu(p0);
                    p0 += c_byte_step;
                    Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
                    Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
                    tile_base += dst_step;
                }
            }
        }
    } else {
        // Fallback: scatter one K-tile per call (region 2047, masked).
        const int c_step = HMX_FP16_TILE_N_COLS;
        const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
        const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
        const int n_c_iters = k / c_step;

        for (int r = start_row; r < end_row; r += 2) {
            const int ct = r / HMX_FP16_TILE_N_ROWS;
            const int local_r = r % HMX_FP16_TILE_N_ROWS;
            const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
            const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
            const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);

            __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
            const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
            const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;

            if (p1) {
                for (int i = 0; i < n_c_iters; ++i) {
                    HVX_Vector v0 = hvx_vmemu(p0);
                    p0 += c_byte_step;
                    HVX_Vector v1 = hvx_vmemu(p1);
                    p1 += c_byte_step;
                    Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
                    Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1);
                    tile_base += dst_step;
                }
            } else {
                const HVX_Vector vzero = Q6_V_vzero();
                for (int i = 0; i < n_c_iters; ++i) {
                    HVX_Vector v0 = hvx_vmemu(p0);
                    p0 += c_byte_step;
                    Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
                    Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
                    tile_base += dst_step;
                }
            }
        }
    }
}
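
// Usage sketch (illustrative, not from the commit): a full-range, single-threaded
// interleave of a [n_cols][src_stride] fp16 scratch buffer; the sizes here are
// assumptions. Multi-threaded callers slice [start_row, end_row) per thread instead.
//
//   hmx_interleave_rows_to_tiles(vtcm_dst, vtcm_src, /*n_cols*/ 128, /*k*/ 256,
//                                /*src_stride*/ 256, /*start_row*/ 0, /*end_row*/ 128);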

// Interleave row-major FP16 data into column-major tile format.
// Input: [n_rows, head_dim] row-major. Output: tile[dim_tile][row_tile].
// Processes rows [start_row, end_row) for multi-thread slicing.
// Full range: start_row=0, end_row=n_rows.
static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
                                                const __fp16 * restrict src,
                                                int n_rows,
                                                int head_dim,
                                                int src_stride,
                                                int n_row_tiles,
                                                int start_row,
                                                int end_row) {
    __builtin_assume(head_dim > 0);
    const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS;

    for (int r = start_row; r < end_row; r += 2) {
        const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows;

        const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride);
        const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL;

        // Row-pair invariants hoisted out of the c loop.
        const int r0 = r / HMX_FP16_TILE_N_ROWS;
        const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;

        // tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0).
        // Each c step (+= 64) advances both by 2 dim-tiles worth of fp16.
        __fp16 * tb0 = tiles_out + (size_t) r0 * HMX_FP16_TILE_N_ELMS;
        __fp16 * tb1 = tb0 + tile_stride_elms;
        const size_t tb_step = 2 * tile_stride_elms;

        if (pv_in1) {
            for (int c = 0; c < head_dim; c += 64) {
                HVX_Vector v0 = *pv_in0++;
                HVX_Vector v1 = *pv_in1++;
                HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2);
                ((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
                ((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp);
                tb0 += tb_step;
                tb1 += tb_step;
            }
        } else {
            const HVX_Vector vzero = Q6_V_vzero();
            for (int c = 0; c < head_dim; c += 64) {
                HVX_Vector v0 = *pv_in0++;
                HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2);
                ((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
                ((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp);
                tb0 += tb_step;
                tb1 += tb_step;
            }
        }
    }
}

#endif // HMX_UTILS_H
@@ -20,7 +20,7 @@ struct htp_mmap {
    uint64_t size;
    uint64_t base;
    uint32_t fd;
-   uint32_t pinned;
+   uint32_t reserved;
};

// Scratchpad state
@@ -77,6 +77,8 @@ struct htp_context {
    atomic_bool vtcm_valid;
    atomic_bool vtcm_needs_release;

+   uint64_t max_vmem;
+
    struct htp_ops_context octx;

#ifdef HTP_HAS_HMX
@@ -104,5 +106,6 @@ int op_cumsum(struct htp_ops_context * octx);
int op_fill(struct htp_ops_context * octx);
int op_diag(struct htp_ops_context * octx);
int op_solve_tri(struct htp_ops_context * octx);
+int op_gated_delta_net(struct htp_ops_context * octx);

#endif /* HTP_CTX_H */
@@ -83,6 +83,9 @@ enum htp_op_code {
    HTP_OP_FILL,
    HTP_OP_DIAG,
    HTP_OP_SOLVE_TRI,
+   HTP_OP_L2_NORM,
+   HTP_OP_GATED_DELTA_NET,

    HTP_OP_INVALID
};
@@ -90,15 +93,11 @@ enum htp_op_code {
#define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS
#define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS

-#define HTP_OP_MAX_BUFS 8
+#define HTP_OP_MAX_BUFS 16
#define HTP_OP_MAX_REQS 256
#define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS)

-#if __HVX_ARCH__ < 75
-#define HTP_OP_MAX_VMEM (3167538380u)
-#else
-#define HTP_OP_MAX_VMEM (3221225472u)
-#endif
+#define HTP_OP_MAX_VMEM_DEFAULT (3355443200u)

#define HTP_MMAP_MAX_VMEM (2147483648u)

@@ -11,9 +11,9 @@ struct htp_iface_pmu_conf {
};

interface htp_iface : remote_handle64 {
-   AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx);
+   AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem);
    AEEResult stop();
-   AEEResult mmap(in uint32 fd, in uint32 size, in uint32 pinned);
+   AEEResult mmap(in uint32 fd, in uint32 size);
    AEEResult munmap(in uint32 fd);
    AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu);
    AEEResult etm(in uint32 enable);
@@ -77,6 +77,12 @@ static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
    return x;
}

+static inline _Float16 hvx_vec_get_f16(HVX_Vector v) {
+    _Float16 __attribute__((aligned(128))) x;
+    hvx_vec_store_a(&x, 2, v);
+    return x;
+}
+
static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
    // abs by clearing the fp16 sign bit
    HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
@@ -7,7 +7,8 @@

#include "hvx-base.h"

-#define hvx_splat_loop_body(dst_type, vec_store) \
+#define hvx_splat_pragma(x) _Pragma(#x)
+#define hvx_splat_loop_body(dst_type, vec_store, unroll_cnt) \
    do { \
        dst_type * restrict vdst = (dst_type *) dst; \
\
@@ -16,7 +17,7 @@
\
        uint32_t i = 0; \
\
-       _Pragma("unroll(4)") \
+       hvx_splat_pragma(unroll(unroll_cnt)) \
        for (; i < nvec; i++) { \
            vdst[i] = src; \
        } \
@@ -25,31 +26,47 @@
        } \
    } while(0)

-static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
+static inline void hvx_splat_a(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
    assert((unsigned long) dst % 128 == 0);
-   hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a);
+   hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a, 4);
}

-static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
-   hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u);
+static inline void hvx_splat_u(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
+   hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u, 4);
}

-static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) {
+static inline void hvx_splat_f32_a(void * restrict dst, float v, uint32_t n) {
    hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float));
}

-static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) {
+static inline void hvx_splat_f32_u(void * restrict dst, float v, uint32_t n) {
    hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float));
}

-static inline void hvx_splat_f16_a(uint8_t * restrict dst, _Float16 v, uint32_t n) {
+static inline void hvx_splat_f16_a(void * restrict dst, _Float16 v, uint32_t n) {
    hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
}

-static inline void hvx_splat_f16_u(uint8_t * restrict dst, _Float16 v, uint32_t n) {
+static inline void hvx_splat_f16_u(void * restrict dst, _Float16 v, uint32_t n) {
    hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
}

+static inline void hvx_splat_u16_a(void * restrict dst, uint16_t v, uint32_t n) {
+    hvx_splat_a(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t));
+}
+
+static inline void hvx_splat_u16_u(void * restrict dst, uint16_t v, uint32_t n) {
+    hvx_splat_u(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t));
+}
+
+static inline void hvx_splat_u8_a(void * restrict dst, uint8_t v, uint32_t n) {
+    hvx_splat_a(dst, Q6_Vb_vsplat_R(v), n, 1);
+}
+
+static inline void hvx_splat_u8_u(void * restrict dst, uint8_t v, uint32_t n) {
+    hvx_splat_u(dst, Q6_Vb_vsplat_R(v), n, 1);
+}

#define hvx_copy_loop_body(dst_type, src_type, vec_store) \
    do { \
        dst_type * restrict vdst = (dst_type *) dst; \
|
||||
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
|
||||
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
|
||||
#define EXP_ONE (0x3f800000) // 1.0
|
||||
#define EXP_RANGE_R (0x42B16666) // 88.7
|
||||
#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228
|
||||
#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
|
||||
|
||||
static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
|
||||
@@ -163,7 +163,7 @@ static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict
|
||||
HVX_Vector vec_out = Q6_V_vzero();
|
||||
|
||||
static const float kInf = INFINITY;
|
||||
static const float kMaxExp = 88.7f;
|
||||
static const float kMaxExp = 88.7228f;
|
||||
|
||||
const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
|
||||
|
||||
@@ -210,7 +210,7 @@ AEEResult htp_iface_close(remote_handle64 handle) {
    return AEE_SUCCESS;
}

-AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32 pinned) {
+AEEResult htp_iface_mmap(remote_handle64 handle, uint32_t fd, uint32_t size) {
    struct htp_context * ctx = (struct htp_context *) handle;
    if (!ctx) {
        return AEE_EBADPARM;
@@ -220,7 +220,6 @@ AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32
    for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) {
        struct htp_mmap *m = &ctx->mmap[i];
        if (m->fd == fd) {
-           m->pinned = pinned;
            return AEE_SUCCESS;
        }
    }
@@ -229,7 +228,7 @@ AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32
    for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) {
        struct htp_mmap *m = &ctx->mmap[i];
        if (!m->size) {
-           FARF(HIGH, "mmap : fd %u size %u pinned %u", fd, size, pinned);
+           FARF(HIGH, "mmap : fd %u size %u", fd, size);
#if __HVX_ARCH__ > 73
            void *va = HAP_mmap2(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0);
#else
@@ -248,7 +247,6 @@ AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32
            m->base = (uint64_t) va;
            m->fd = fd;
            m->size = size;
-           m->pinned = pinned;

            return AEE_SUCCESS;
        }
@@ -275,7 +273,6 @@ AEEResult htp_iface_munmap(remote_handle64 handle, uint32 fd) {
            m->size = 0;
            m->base = NULL;
            m->fd = -1;
-           m->pinned = 0;
        }
    }
@@ -358,7 +355,7 @@ static void vtcm_free(struct htp_context * ctx) {
static void htp_packet_callback(dspqueue_t queue, int error, void * context);
static void htp_error_callback(dspqueue_t queue, int error, void * context);

-AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx) {
+AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) {
    struct htp_context * ctx = (struct htp_context *) handle;

    if (!ctx) {
@@ -376,12 +373,12 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
        htp_error_callback, // Error callback; no errors expected on the DSP
        (void *) ctx, // Callback context
        &ctx->queue);

    if (err) {
        FARF(ERROR, "Queue import failed with 0x%08x", (unsigned) err);
        return err;
    }

+   ctx->max_vmem = max_vmem;
    ctx->thread_id = qurt_thread_get_id();
    ctx->thread_prio = qurt_thread_get_priority(ctx->thread_id);
|
||||
case HTP_OP_UNARY_SIGMOID:
|
||||
case HTP_OP_UNARY_NEG:
|
||||
case HTP_OP_UNARY_EXP:
|
||||
case HTP_OP_L2_NORM:
|
||||
return op_unary(octx);
|
||||
|
||||
case HTP_OP_UNARY_SILU:
|
||||
@@ -596,6 +594,9 @@ static int execute_op(struct htp_ops_context * octx) {
|
||||
case HTP_OP_SOLVE_TRI:
|
||||
return op_solve_tri(octx);
|
||||
|
||||
case HTP_OP_GATED_DELTA_NET:
|
||||
return op_gated_delta_net(octx);
|
||||
|
||||
case HTP_OP_INVALID:
|
||||
break;
|
||||
|
||||
@@ -622,8 +623,8 @@ static inline bool reuse_buf(struct htp_context *ctx, uint32_t *m_reuse, struct
|
||||
}
|
||||
|
||||
static inline void drop_mmap(struct htp_context *ctx, struct htp_mmap *m) {
|
||||
if (m->size && !m->pinned) {
|
||||
FARF(HIGH, "unmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned);
|
||||
if (m->size) {
|
||||
FARF(HIGH, "unmap : fd %u base %p size %u", m->fd, (void*) m->base, (uint32_t) m->size);
|
||||
#if __HVX_ARCH__ > 73
|
||||
HAP_munmap2((void *) m->base, m->size);
|
||||
#else
|
||||
@@ -660,9 +661,8 @@ static inline void mmap_buf(struct htp_context *ctx, struct htp_buf_desc *b) {
|
||||
m->base = b->base = (uint64_t) va;
|
||||
m->fd = b->fd;
|
||||
m->size = b->size;
|
||||
m->pinned = 0;
|
||||
|
||||
FARF(HIGH, "mmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned);
|
||||
FARF(HIGH, "mmap : fd %u base %p size %u", m->fd, (void*) m->base, (uint32_t) m->size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -672,8 +672,8 @@ static void prep_op_bufs(struct htp_context *ctx, struct htp_buf_desc *bufs, uin
|
||||
uint32_t m_reuse = 0; // mmap reuse mask (index from ctx->mmap array)
|
||||
uint32_t b_reuse = 0; // buf reuse count
|
||||
|
||||
size_t m_vmem = 0; // mapped vmem
|
||||
size_t e_vmem = 0; // extra vmem
|
||||
uint64_t m_vmem = 0; // mapped vmem
|
||||
uint64_t e_vmem = 0; // extra vmem
|
||||
|
||||
// See what we can reuse
|
||||
for (uint32_t i=0; i < n_bufs; i++) {
|
||||
@@ -687,9 +687,10 @@ static void prep_op_bufs(struct htp_context *ctx, struct htp_buf_desc *bufs, uin
|
||||
// See how much vmem we have mmaped right now
|
||||
for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) { m_vmem += ctx->mmap[i].size; }
|
||||
|
||||
FARF(HIGH, "prep-bufs : pass1 mmap-vmem %zu extra-vmem %zu n-bufs %u b-reuse %u", m_vmem, e_vmem, n_bufs, b_reuse);
|
||||
FARF(HIGH, "prep-bufs : pass1 mmap-vmem %zu extra-vmem %zu max-vmem %zu : n-bufs %u b-reuse %u",
|
||||
(size_t) m_vmem, (size_t) e_vmem, (size_t) ctx->max_vmem, n_bufs, b_reuse);
|
||||
|
||||
if ((m_vmem + e_vmem) > HTP_OP_MAX_VMEM) {
|
||||
if ((m_vmem + e_vmem) > ctx->max_vmem) {
|
||||
// Drop unused mappings
|
||||
for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) {
|
||||
bool used = m_reuse & (1<<i);
|
||||
|
||||
@@ -2991,12 +2991,10 @@ int op_matmul(struct htp_ops_context * octx) {
        return op_matmul_hvx(octx);
    }

-   // M alignment: when M > 32 but not 32-aligned, we split into
-   // HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
-   // When M <= 32 and not 32-aligned, fall back entirely to HVX.
+   // M alignment: Use HMX when M >= 32, the last partial tile (m_total % 32 rows)
+   // is handled by HMX itself; when M < 32 fall back to HVX.
    const int m_total = (int) src1->ne[1];
-   const int m_tail = m_total % 32;
-   const int m_hmx = m_total - m_tail;
+   const int m_hmx = m_total & ~31; // 0 when M < 32

    if (m_hmx == 0) {
        return op_matmul_hvx(octx);
@@ -3009,7 +3007,6 @@
    int k = (int) src0->ne[0]; // inner dimension
    int n = (int) src0->ne[1]; // weight columns

-   // --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
    int ret = -1;

    // Row strides in elements. For compact tensors these equal k; for
@@ -3027,7 +3024,7 @@
        .dst = (float *) dst->data,
        .activation = (float *) src1->data,
        .permuted_weight = (const __fp16 *) src0->data,
-       .m = m_hmx,
+       .m = m_total,
        .k = k,
        .n = n,
        .act_stride = act_stride,
@@ -3048,12 +3045,12 @@
        } else {
            ret = hmx_mat_mul_permuted_w16a32(octx->ctx,
                (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data,
-               m_hmx, k, n, act_stride, wgt_stride);
+               m_total, k, n, act_stride, wgt_stride);
        }
    } else {
        ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx,
            (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
-           m_hmx, k, n, (int) src0->type);
+           m_total, k, n, (int) src0->type);
    }

    if (ret != 0) {
@@ -3061,27 +3058,6 @@
        return op_matmul(octx);
    }

-   // --- Phase 2: HVX on the remaining m_tail rows ---
-   if (m_tail > 0) {
-       // copy of src1 and dst
-       struct htp_tensor src1_tail = *src1;
-       struct htp_tensor dst_tail = *dst;
-
-       src1_tail.ne[1] = m_tail; // only tail rows
-       dst_tail.ne[1] = m_tail; // only tail rows
-
-       // Offset activation and dst pointers past the HMX-processed rows.
-       // Use nb[1] (row stride in bytes) to compute the byte offset.
-       src1_tail.data += (uint32_t) m_hmx * src1->nb[1];
-       dst_tail.data += (uint32_t) m_hmx * dst->nb[1];
-
-       octx->src[1] = &src1_tail;
-       octx->dst = &dst_tail;
-
-       FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data);
-       return op_matmul_hvx(octx);
-   }
-
    return 0;
#endif // HTP_HAS_HMX
}
|
||||
const uint8_t * data_src0;
|
||||
uint8_t * data_dst;
|
||||
|
||||
size_t src0_row_size;
|
||||
size_t dst_row_size;
|
||||
size_t src0_data_row_size; // actual data bytes per row
|
||||
size_t dst_data_row_size; // actual data bytes per row
|
||||
|
||||
size_t src0_row_size_aligned;
|
||||
size_t dst_row_size_aligned;
|
||||
@@ -41,6 +41,40 @@ struct htp_unary_context {
|
||||
uint32_t nc;
|
||||
};
|
||||
|
||||
// Convert flat row index to DDR byte offset using the tensor's actual strides.
|
||||
// ir = i1 + ne1*(i2 + ne2*i3) => offset = i1*nb1 + i2*nb2 + i3*nb3
|
||||
static inline size_t unary_row_offset(uint32_t ir,
|
||||
uint32_t ne1, uint32_t ne2,
|
||||
size_t nb1, size_t nb2, size_t nb3) {
|
||||
const uint32_t i1 = ir % ne1;
|
||||
const uint32_t i2 = (ir / ne1) % ne2;
|
||||
const uint32_t i3 = ir / (ne1 * ne2);
|
||||
return i1 * nb1 + i2 * nb2 + i3 * nb3;
|
||||
}
|
||||
// Safe DMA block size from row `ir`: clamp to the tighter dim-1 slice
|
||||
// boundary of src and dst so the nb1 stride stays valid for all rows.
|
||||
static inline uint32_t unary_block_size(uint32_t ir,
|
||||
uint32_t end_row,
|
||||
uint32_t block,
|
||||
bool src_contig,
|
||||
bool dst_contig,
|
||||
uint32_t src_ne1,
|
||||
uint32_t dst_ne1) {
|
||||
uint32_t limit = MIN(block, end_row - ir);
|
||||
|
||||
if (!src_contig) {
|
||||
const uint32_t src_slice_end = (ir / src_ne1 + 1) * src_ne1;
|
||||
limit = MIN(limit, src_slice_end - ir);
|
||||
}
|
||||
|
||||
if (!dst_contig) {
|
||||
const uint32_t dst_slice_end = (ir / dst_ne1 + 1) * dst_ne1;
|
||||
limit = MIN(limit, dst_slice_end - ir);
|
||||
}
|
||||
|
||||
return limit;
|
||||
}
|
||||
|
||||
#define htp_unary_preamble \
|
||||
const uint32_t ne00 = src->ne[0]; \
|
||||
const uint32_t ne01 = src->ne[1]; \
|
||||
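To make the index math above concrete, here is a small standalone check (shapes and strides are made up for illustration; `unary_row_offset` is copied from the hunk above, and the clamp computation mirrors `unary_block_size`):

    #include <stdint.h>
    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    static inline size_t unary_row_offset(uint32_t ir, uint32_t ne1, uint32_t ne2,
                                          size_t nb1, size_t nb2, size_t nb3) {
        const uint32_t i1 = ir % ne1;
        const uint32_t i2 = (ir / ne1) % ne2;
        const uint32_t i3 = ir / (ne1 * ne2);
        return i1 * nb1 + i2 * nb2 + i3 * nb3;
    }

    int main(void) {
        // Hypothetical f32 view: ne1 = 4 rows per slice, ne2 = 3 slices,
        // rows padded to nb1 = 64 bytes, nb2 = 4*64, nb3 = 3*4*64.
        const uint32_t ne1 = 4, ne2 = 3;
        const size_t nb1 = 64, nb2 = 256, nb3 = 768;

        // Flat row 7 decomposes as i1 = 3, i2 = 1, i3 = 0 -> 3*64 + 1*256 = 448.
        printf("row 7 offset = %zu bytes\n", unary_row_offset(7, ne1, ne2, nb1, nb2, nb3));

        // unary_block_size clamps a block starting at row 3 to the slice edge:
        // the dim-1 slice covering row 3 ends at row 4, so at most 1 row fits.
        const uint32_t ir = 3, block = 8, slice_end = (ir / ne1 + 1) * ne1;
        printf("clamped block at row 3 = %u row(s)\n", MIN(block, slice_end - ir));
        return 0;
    }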
@@ -264,6 +298,81 @@ static void softplus_f32(const float * restrict src,
     }
 }

+// --- L2_NORM HVX kernel ---
+// Computes y[i] = x[i] / fmax(sqrt(sum(x[j]^2)), epsilon) for each row.
+// scale = 1/fmax(sqrt(sum), epsilon) is computed entirely in HVX registers
+// using rsqrt + inverse to avoid scalar extraction.
+static void hvx_fast_l2_norm_f32(const uint8_t * restrict src,
+                                 uint8_t * restrict dst,
+                                 uint8_t * restrict pad,
+                                 const int num_elems,
+                                 float epsilon) {
+    (void)pad;
+
+    const HVX_Vector * restrict v_src = (HVX_Vector *) src;
+    HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
+
+    HVX_Vector sum_v = hvx_vec_splat_f32(0.0f);
+
+    const int nvec = num_elems / VLEN_FP32;
+    const int nloe = num_elems % VLEN_FP32;
+
+#pragma unroll(4)
+    for (int i = 0; i < nvec; i++) {
+        HVX_Vector v1 = v_src[i];
+        HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
+        sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq);
+    }
+
+    // Include tail elements in the sum-of-squares using a predicate mask
+    if (nloe > 0) {
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
+        HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
+        HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
+        sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq);
+    }
+
+    // Compute scale = 1/fmax(sqrt(sum), epsilon) entirely in HVX registers.
+    // hvx_vec_rsqrt_f32 + hvx_vec_inverse_f32 avoids scalar extraction.
+    HVX_Vector sum_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
+    HVX_Vector rsqrt_v = hvx_vec_rsqrt_f32(sum_sf);               // 1/sqrt(sum)
+    HVX_Vector sqrt_v = hvx_vec_inverse_f32(rsqrt_v);             // sqrt(sum)
+    HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
+    HVX_Vector denom_v = Q6_Vsf_vmax_VsfVsf(sqrt_v, epsilon_v);   // fmax(sqrt(sum), epsilon)
+    HVX_Vector scale_v = hvx_vec_inverse_f32(denom_v);            // 1/fmax(sqrt(sum), epsilon)
+
+#pragma unroll(4)
+    for (int i = 0; i < nvec; i++) {
+        HVX_Vector v1 = v_src[i];
+        v_dst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v));
+    }
+
+    if (nloe > 0) {
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
+        HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
+        HVX_Vector result = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v));
+        hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
+    }
+}
+
+static void l2_norm_f32(const float * restrict src,
+                        float * restrict dst,
+                        uint8_t * restrict spad,
+                        const uint32_t num_rows,
+                        const uint32_t row_elems,
+                        const size_t row_size,
+                        int32_t * op_params) {
+    float epsilon = 0.f;
+    memcpy(&epsilon, op_params, sizeof(float));
+
+    for (uint32_t ir = 0; ir < num_rows; ir++) {
+        const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size));
+        float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size));
+
+        hvx_fast_l2_norm_f32((const uint8_t *)src_f, (uint8_t *)dst_f, spad, row_elems, epsilon);
+    }
+}
+
 static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
     const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
     struct htp_ops_context * octx = uctx->octx;
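For reference, the per-row math the HVX kernel implements, written as plain scalar C (a sketch for clarity, not part of the patch):

    #include <math.h>
    #include <stdio.h>

    // y[i] = x[i] / fmax(sqrt(sum(x[j]^2)), epsilon) -- same formula as
    // hvx_fast_l2_norm_f32, minus the qf32 accumulation and predicate masks.
    static void l2_norm_row_ref(const float * x, float * y, int n, float epsilon) {
        float sum = 0.0f;
        for (int j = 0; j < n; j++) {
            sum += x[j] * x[j];
        }
        const float scale = 1.0f / fmaxf(sqrtf(sum), epsilon); // HVX derives this via rsqrt + inverse
        for (int i = 0; i < n; i++) {
            y[i] = x[i] * scale;
        }
    }

    int main(void) {
        const float x[3] = { 3.0f, 4.0f, 0.0f };
        float y[3];
        l2_norm_row_ref(x, y, 3, 1e-6f); // norm = 5 -> y = {0.6, 0.8, 0.0}
        printf("%g %g %g\n", y[0], y[1], y[2]);
        return 0;
    }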
@@ -276,8 +385,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
     int32_t * op_params = octx->op_params;
     uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;

-    const size_t src0_row_size = uctx->src0_row_size;
-    const size_t dst_row_size = uctx->dst_row_size;
+    const size_t src0_data_row_size = uctx->src0_data_row_size;
+    const size_t dst_data_row_size = uctx->dst_data_row_size;

     const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
     const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
@@ -303,7 +412,16 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
     size_t src0_spad_half_size = uctx->src0_spad_half_size;
     size_t dst_spad_half_size = uctx->dst_spad_half_size;

-    const int BLOCK = uctx->block;
+    // Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride
+    // 2D DMA descriptor cannot span. Clamp BLOCK to ne1 (one dim-1 slice) so every
+    // transfer stays within a nb1-uniform region. Skipped for contiguous tensors.
+    const bool src0_contig = (nb02 == (size_t)ne01 * nb01) &&
+                             (nb03 == (size_t)ne02 * nb02);
+    const bool dst_contig = (nb2 == (size_t)ne1 * nb1) &&
+                            (nb3 == (size_t)ne2 * nb2);
+    const uint32_t src0_max_block = src0_contig ? uctx->block : MIN((uint32_t)uctx->block, ne01);
+    const uint32_t dst_max_block = dst_contig ? uctx->block : MIN((uint32_t)uctx->block, ne1);
+    const uint32_t BLOCK = MIN(src0_max_block, dst_max_block);
     if (BLOCK == 0) {
         FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
              octx->src0_spad.size_per_thread, src0_row_size_aligned);
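A quick illustration of when the clamp fires (all numbers hypothetical): a view whose dim-1 slices are spaced further apart than `ne01 * nb01` is non-contiguous, so BLOCK drops to one slice:

    #include <stdint.h>
    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        // 4 rows per slice, 32 bytes per row; the parent spaces slices 256 bytes
        // apart, while a contiguous layout would use 4 * 32 = 128.
        const uint32_t ne01 = 4, ne02 = 8;
        const size_t nb01 = 32, nb02 = 256, nb03 = 2048;

        const int contig = (nb02 == (size_t) ne01 * nb01) && (nb03 == (size_t) ne02 * nb02);
        const uint32_t vtcm_block = 16; // rows that would fit in the VTCM half
        const uint32_t BLOCK = contig ? vtcm_block : MIN(vtcm_block, ne01);

        printf("contig = %d, BLOCK = %u\n", contig, BLOCK); // contig = 0, BLOCK = 4
        return 0;
    }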
@@ -312,21 +430,23 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *

     dma_queue * dma_queue = octx->ctx->dma[ith];

-    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
-        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) {
+        const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);

         // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
-        dma_queue_push_vtcm_to_ddr(dma_queue,
+        dma_queue_push(dma_queue,
             dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
-            dst_row_size, dst_row_size_aligned, 0);
+            nb1, dst_row_size_aligned, dst_data_row_size, 0);

-        dma_queue_push_ddr_to_vtcm(dma_queue,
-            dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
-            src0_row_size_aligned, src0_row_size, block_size);
+        const size_t src0_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03);
+        dma_queue_push(dma_queue,
+            dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off),
+            src0_row_size_aligned, nb01, src0_data_row_size, block_size);
+        ir += block_size;
     }

-    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
-        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+    for (uint32_t ir = src0_start_row; ir < src0_end_row; ) {
+        const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);

         float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
         float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
@@ -357,22 +477,32 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
             case HTP_OP_UNARY_SOFTPLUS:
                 softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
                 break;
+            case HTP_OP_L2_NORM:
+                l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+                break;
             default:
                 break;
         }

-        dma_queue_push_vtcm_to_ddr(dma_queue,
-            dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
-            dst_row_size, dst_row_size_aligned, block_size);
+        const size_t dst_off = unary_row_offset(ir, ne1, ne2, nb1, nb2, nb3);
+        dma_queue_push(dma_queue,
+            dma_make_ptr(data_dst + dst_off, dst_spad),
+            nb1, dst_row_size_aligned, dst_data_row_size, block_size);

         // prefetch N+2 loop iteration if any
-        const uint32_t pref_block = (ir + BLOCK * 2);
-        if (pref_block < src0_end_row) {
-            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
-            dma_queue_push_ddr_to_vtcm(dma_queue,
-                dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),
-                src0_row_size_aligned, src0_row_size, pref_block_size);
+        const uint32_t next_ir = ir + block_size;
+        if (next_ir < src0_end_row) {
+            const uint32_t next_block_size = unary_block_size(next_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
+            const uint32_t pref_ir = next_ir + next_block_size;
+            if (pref_ir < src0_end_row) {
+                const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
+                const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03);
+                dma_queue_push(dma_queue,
+                    dma_make_ptr(src0_spad, data_src + src0_pref_off),
+                    src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size);
+            }
         }
+        ir += block_size;
     }

     dma_queue_flush(dma_queue);
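The prefetch rewrite keeps the same two-blocks-in-flight pipeline as before, but now steps by the variable `block_size` instead of a fixed BLOCK, so it must walk two clamped blocks ahead to find the prefetch target. A toy model of the schedule (printf stubs stand in for the real `dma_queue_*` calls; fixed block sizes for simplicity):

    #include <stdio.h>

    // Toy model of the two-deep DMA pipeline: two VTCM halves, each refilled
    // with the block two steps ahead as soon as its current block is done.
    static void issue_load (int blk, int half) { printf("load   blk %d -> spad half %d\n", blk, half); }
    static void compute    (int blk, int half) { printf("compute blk %d in spad half %d\n", blk, half); }
    static void issue_store(int blk, int half) { printf("store  blk %d <- spad half %d\n", blk, half); }

    int main(void) {
        const int nblocks = 5;
        // prime: two blocks in flight, one per VTCM half (the first loop above)
        for (int i = 0; i < 2 && i < nblocks; i++) issue_load(i, i);
        // steady state: consume block i, then refill its half with block i+2
        for (int i = 0; i < nblocks; i++) {
            const int half = i % 2;
            compute(i, half);
            issue_store(i, half);
            if (i + 2 < nblocks) issue_load(i + 2, half);
        }
        return 0;
    }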
@@ -417,6 +547,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
         case HTP_OP_UNARY_SOFTPLUS:
             op_type = "softplus-f32";
             break;
+        case HTP_OP_L2_NORM:
+            op_type = "l2norm-f32";
+            break;

         default:
             FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
@@ -426,11 +559,11 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
     const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
     const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);

-    const size_t src0_row_size = src0->nb[1];
-    const size_t dst_row_size = dst->nb[1];
+    const size_t src0_data_row_size = src0->ne[0] * sizeof(float);
+    const size_t dst_data_row_size = dst->ne[0] * sizeof(float);

-    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
-    const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+    const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN);
+    const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN);

     // VTCM scratchpads for all tensors
     // N rows per thread, padded to HVX vector size
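The distinction driving this change, with made-up numbers: `nb[1]` is the stride between row starts (which can include padding), while `ne[0] * sizeof(float)` is the payload that actually needs to move; the DMA now receives both separately:

    #include <stddef.h>
    #include <stdio.h>

    int main(void) {
        // Hypothetical padded layout: 60 floats of payload, rows 64 floats apart.
        const size_t ne0 = 60;
        const size_t data_row_size = ne0 * sizeof(float); // 240 B copied per row
        const size_t nb1 = 64 * sizeof(float);            // 256 B between row starts
        printf("copy %zu bytes per row, stepping %zu bytes between rows\n",
               data_row_size, nb1);
        return 0;
    }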
@@ -468,8 +601,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
         .data_src0 = (const uint8_t *)src0->data,
         .data_dst = (uint8_t *)dst->data,

-        .src0_row_size = src0_row_size,
-        .dst_row_size = dst_row_size,
+        .src0_data_row_size = src0_data_row_size,
+        .dst_data_row_size = dst_data_row_size,

         .src0_row_size_aligned = src0_row_size_aligned,
         .dst_row_size_aligned = dst_row_size_aligned,

ggml/src/ggml-hexagon/htp/vtcm-utils.h (new file, 16 lines)
@@ -0,0 +1,16 @@
+#ifndef VTCM_UTILS_H
+#define VTCM_UTILS_H
+
+#include "hex-utils.h"
+
+#include <assert.h>
+#include <stdint.h>
+#include <hexagon_types.h>
+
+static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
+    uint8_t *p = *vtcm_ptr;
+    *vtcm_ptr += size;
+    return p;
+}
+
+#endif // VTCM_UTILS_H
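Usage is a simple bump allocation over a single VTCM reservation, with no free(); a standalone sketch (sizes hypothetical, a static buffer stands in for real VTCM):

    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>

    // Same bump allocator as vtcm-utils.h: hands out consecutive byte ranges
    // from one reservation by advancing the cursor.
    static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
        uint8_t *p = *vtcm_ptr;
        *vtcm_ptr += size;
        return p;
    }

    int main(void) {
        static uint8_t vtcm[4096]; // stand-in for the real VTCM window
        uint8_t * cur = vtcm;
        uint8_t * src_spad = vtcm_seq_alloc(&cur, 2048); // hypothetical split
        uint8_t * dst_spad = vtcm_seq_alloc(&cur, 2048);
        printf("src at +%td, dst at +%td\n", src_spad - vtcm, dst_spad - vtcm);
        return 0;
    }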
@@ -282,6 +282,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf);
 void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
 void ggml_metal_buffer_set_tensor (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
 void ggml_metal_buffer_get_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+bool ggml_metal_buffer_cpy_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * src, struct ggml_tensor * dst);
 void ggml_metal_buffer_clear (ggml_metal_buffer_t buf, uint8_t value);

 // finds the Metal buffer that contains the tensor data on the GPU device

@@ -1,6 +1,7 @@
 #import "ggml-metal-device.h"

 #import "ggml-impl.h"
+#import "ggml-backend-impl.h"

 #include <Foundation/Foundation.h>

@@ -1737,6 +1738,47 @@ void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_ten
     }
 }

+bool ggml_metal_buffer_cpy_tensor(ggml_metal_buffer_t buf_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_metal_buffer_t buf_src = (ggml_metal_buffer_t)src->buffer->context;
+
+    const size_t size = ggml_nbytes(src);
+
+    // if both buffers are shared, we can use memcpy directly
+    if (buf_dst->is_shared && buf_src->is_shared) {
+        memcpy(dst->data, src->data, size);
+        return true;
+    }
+
+    // for private buffers, we need to use Metal blit commands
+    @autoreleasepool {
+        struct ggml_metal_buffer_id bid_src = ggml_metal_buffer_get_id(buf_src, src);
+        struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf_dst, dst);
+
+        if (bid_src.metal == nil || bid_dst.metal == nil) {
+            return false;
+        }
+
+        id<MTLCommandBuffer> cmd_buf = [buf_dst->dev->mtl_queue commandBufferWithUnretainedReferences];
+
+        {
+            id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
+
+            [encoder copyFromBuffer:bid_src.metal
+                       sourceOffset:bid_src.offs
+                           toBuffer:bid_dst.metal
+                  destinationOffset:bid_dst.offs
+                               size:size];
+
+            [encoder endEncoding];
+        }
+
+        [cmd_buf commit];
+        [cmd_buf waitUntilCompleted];
+    }
+
+    return true;
+}
+
 void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) {
     if (buf->is_shared) {
         memset(buf->all_data, value, buf->all_size);
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user