mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-16 16:27:32 +03:00
Compare commits
49 Commits
b7847
...
gg/compare
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b01d8575d | ||
|
|
1488339138 | ||
|
|
4927795810 | ||
|
|
971facc38e | ||
|
|
d9a2a4bcaa | ||
|
|
dfd6106c84 | ||
|
|
bbada8bfb9 | ||
|
|
13f3ebfae1 | ||
|
|
dabaa2e77a | ||
|
|
2e916f996a | ||
|
|
f3bc98890c | ||
|
|
c3b87cebff | ||
|
|
0562503154 | ||
|
|
83bcdf7217 | ||
|
|
b316895ff9 | ||
|
|
ecbf01d441 | ||
|
|
1025fd2c09 | ||
|
|
c7358ddf64 | ||
|
|
d284baf1b5 | ||
|
|
bd90fc74c3 | ||
|
|
ce38a4db47 | ||
|
|
4fdbc1e4db | ||
|
|
7b7ae857f6 | ||
|
|
84b0a98319 | ||
|
|
b45ef2702c | ||
|
|
f3dd7b8e68 | ||
|
|
eed25bc6b0 | ||
|
|
b33df266d0 | ||
|
|
3bcc990997 | ||
|
|
d4964a7c66 | ||
|
|
50e8962f79 | ||
|
|
f6b533d898 | ||
|
|
72d3b1898a | ||
|
|
ebf5725870 | ||
|
|
0cd7032ca4 | ||
|
|
60368e1d73 | ||
|
|
88d23ad515 | ||
|
|
0a95026da9 | ||
|
|
b7feacf7f3 | ||
|
|
6ad70c5a77 | ||
|
|
631cbfcc7a | ||
|
|
2eee6c866c | ||
|
|
b931f81b5a | ||
|
|
c5c64f72ac | ||
|
|
eef375ce16 | ||
|
|
06961e2876 | ||
|
|
f2571df8b7 | ||
|
|
2b4cbd2834 | ||
|
|
68ac3acb43 |
8
.github/workflows/build.yml
vendored
8
.github/workflows/build.yml
vendored
@@ -21,7 +21,8 @@ on:
|
||||
'**/*.m',
|
||||
'**/*.metal',
|
||||
'**/*.comp',
|
||||
'**/*.glsl'
|
||||
'**/*.glsl',
|
||||
'**/*.wgsl'
|
||||
]
|
||||
|
||||
pull_request:
|
||||
@@ -42,7 +43,8 @@ on:
|
||||
'**/*.m',
|
||||
'**/*.metal',
|
||||
'**/*.comp',
|
||||
'**/*.glsl'
|
||||
'**/*.glsl',
|
||||
'**/*.wgsl'
|
||||
]
|
||||
|
||||
concurrency:
|
||||
@@ -1371,7 +1373,7 @@ jobs:
|
||||
id: update_presets
|
||||
if: ${{ matrix.build == 'arm64-snapdragon' }}
|
||||
run: |
|
||||
cp docs/backend/hexagon/CMakeUserPresets.json .
|
||||
cp docs/backend/snapdragon/CMakeUserPresets.json .
|
||||
|
||||
- name: Build
|
||||
id: ndk_build
|
||||
|
||||
13
.github/workflows/winget.yml
vendored
13
.github/workflows/winget.yml
vendored
@@ -28,16 +28,17 @@ jobs:
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
});
|
||||
console.log("Latest release:", releases[0].tag_name);
|
||||
return releases[0].tag_name;
|
||||
const { tag_name: version, assets: assets } = releases.find(({assets}) => assets.find(asset => asset.name.includes('win-vulkan')));
|
||||
const { browser_download_url: asset_url } = assets.find(asset => asset.name.includes('win-vulkan'));
|
||||
console.log("Latest release:", version);
|
||||
core.setOutput('VERSION', version);
|
||||
core.setOutput('ASSETURL', asset_url);
|
||||
|
||||
- name: Update manifest
|
||||
env:
|
||||
VERSION: ${{ steps.find_latest_release.outputs.result }}
|
||||
run: |
|
||||
echo "Updating manifest..."
|
||||
komac update --version ${{ env.VERSION }} \
|
||||
--urls "https://github.com/ggml-org/llama.cpp/releases/download/${{ env.VERSION }}/llama-${{ env.VERSION }}-bin-win-vulkan-x64.zip" \
|
||||
komac update --version ${{ steps.find_latest_release.outputs.VERSION }} \
|
||||
--urls "${{ steps.find_latest_release.outputs.ASSETURL }}" \
|
||||
--token ${{ secrets.WINGET_GITHUB_TOKEN }} \
|
||||
--submit \
|
||||
ggml.llamacpp
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
/common/jinja/ @ngxson @CISC @aldehir
|
||||
/common/llguidance.* @ggerganov
|
||||
/common/log.* @ggerganov
|
||||
/common/ngram-map.* @srogmann
|
||||
/common/peg-parser.* @aldehir
|
||||
/common/sampling.* @ggerganov
|
||||
/common/speculative.* @ggerganov
|
||||
@@ -67,6 +68,7 @@
|
||||
/ggml/src/ggml-rpc/ @rgerganov
|
||||
/ggml/src/ggml-threading.* @ggerganov
|
||||
/ggml/src/ggml-vulkan/ @0cc4m
|
||||
/ggml/src/ggml-virtgpu/ @kpouget
|
||||
/ggml/src/ggml-webgpu/ @reeselevine
|
||||
/ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM
|
||||
/ggml/src/ggml.c @ggerganov
|
||||
|
||||
@@ -213,6 +213,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [llama.vim](https://github.com/ggml-org/llama.vim) (MIT)
|
||||
- [LARS](https://github.com/abgulati/LARS) (AGPL)
|
||||
- [Llama Assistant](https://github.com/vietanhdev/llama-assistant) (GPL)
|
||||
- [LlamaLib](https://github.com/undreamai/LlamaLib) (Apache-2.0)
|
||||
- [LLMFarm](https://github.com/guinmoon/LLMFarm?tab=readme-ov-file) (MIT)
|
||||
- [LLMUnity](https://github.com/undreamai/LLMUnity) (MIT)
|
||||
- [LMStudio](https://lmstudio.ai/) (proprietary)
|
||||
|
||||
@@ -73,6 +73,10 @@ add_library(${TARGET} STATIC
|
||||
log.h
|
||||
ngram-cache.cpp
|
||||
ngram-cache.h
|
||||
ngram-map.cpp
|
||||
ngram-map.h
|
||||
ngram-mod.cpp
|
||||
ngram-mod.h
|
||||
peg-parser.cpp
|
||||
peg-parser.h
|
||||
preset.cpp
|
||||
|
||||
103
common/arg.cpp
103
common/arg.cpp
@@ -6,6 +6,7 @@
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "speculative.h"
|
||||
#include "preset.h"
|
||||
|
||||
// fix problem with std::min and std::max
|
||||
@@ -579,14 +580,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
params.mmproj = res.mmproj;
|
||||
}
|
||||
// only download mmproj if the current example is using it
|
||||
for (auto & ex : mmproj_examples) {
|
||||
for (const auto & ex : mmproj_examples) {
|
||||
if (ctx_arg.ex == ex) {
|
||||
common_params_handle_model(params.mmproj, params.hf_token, params.offline);
|
||||
break;
|
||||
}
|
||||
}
|
||||
common_params_handle_model(params.speculative.model, params.hf_token, params.offline);
|
||||
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
|
||||
common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline);
|
||||
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
|
||||
}
|
||||
|
||||
// model is required (except for server)
|
||||
@@ -1216,16 +1217,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"-lcs", "--lookup-cache-static"}, "FNAME",
|
||||
"path to static lookup cache to use for lookup decoding (not updated by generation)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.lookup_cache_static = value;
|
||||
params.speculative.lookup_cache_static = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-lcd", "--lookup-cache-dynamic"}, "FNAME",
|
||||
"path to dynamic lookup cache to use for lookup decoding (updated by generation)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.lookup_cache_dynamic = value;
|
||||
params.speculative.lookup_cache_dynamic = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-c", "--ctx-size"}, "N",
|
||||
string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
|
||||
@@ -1295,11 +1296,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"-kvu", "--kv-unified"},
|
||||
{"-no-kvu", "--no-kv-unified"},
|
||||
"use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)",
|
||||
[](common_params & params) {
|
||||
params.kv_unified = true;
|
||||
[](common_params & params, bool value) {
|
||||
params.kv_unified = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED}));
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
|
||||
add_opt(common_arg(
|
||||
{"--context-shift"},
|
||||
{"--no-context-shift"},
|
||||
@@ -2198,18 +2200,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
add_opt(common_arg(
|
||||
{"--mmap"},
|
||||
{"--no-mmap"},
|
||||
string_format("whether to memory-map model. Explicitly enabling mmap disables direct-io. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: %s)", params.use_mmap ? "enabled" : "disabled"),
|
||||
string_format("whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: %s)", params.use_mmap ? "enabled" : "disabled"),
|
||||
[](common_params & params, bool value) {
|
||||
params.use_mmap = value;
|
||||
if (value) {
|
||||
params.use_direct_io = false; // disable direct io when mmap is explicitly enabled
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_ARG_MMAP"));
|
||||
add_opt(common_arg(
|
||||
{"-dio", "--direct-io"},
|
||||
{"-ndio", "--no-direct-io"},
|
||||
string_format("use DirectIO if available. Takes precedence over --mmap (default: %s)", params.use_direct_io ? "enabled" : "disabled"),
|
||||
string_format("use DirectIO if available. (default: %s)", params.use_direct_io ? "enabled" : "disabled"),
|
||||
[](common_params & params, bool value) {
|
||||
params.use_direct_io = value;
|
||||
}
|
||||
@@ -2565,7 +2564,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
|
||||
"Same as --hf-repo, but for the draft model (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.model.hf_repo = value;
|
||||
params.speculative.mparams_dft.hf_repo = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_HFD_REPO"));
|
||||
add_opt(common_arg(
|
||||
@@ -3386,7 +3385,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"-md", "--model-draft"}, "FNAME",
|
||||
"draft model for speculative decoding (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.model.path = value;
|
||||
params.speculative.mparams_dft.path = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_MODEL_DRAFT"));
|
||||
add_opt(common_arg(
|
||||
@@ -3396,6 +3395,68 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.speculative.replacements.push_back({ tgt, dft });
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
|
||||
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
|
||||
common_speculative_type_to_str(params.speculative.type).c_str()),
|
||||
[](common_params & params, const std::string & value) {
|
||||
if (value == "none") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||
} else if (value == "ngram-cache") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
|
||||
} else if (value == "ngram-simple") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE;
|
||||
} else if (value == "ngram-map-k") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
|
||||
} else if (value == "ngram-map-k4v") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
|
||||
} else if (value == "ngram-mod") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
|
||||
} else {
|
||||
throw std::invalid_argument("unknown speculative decoding type without draft model");
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-n"}, "N",
|
||||
string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1 || value > 1024) {
|
||||
throw std::invalid_argument("ngram size N must be between 1 and 1024 inclusive");
|
||||
}
|
||||
params.speculative.ngram_size_n = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-m"}, "N",
|
||||
string_format("ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)", params.speculative.ngram_size_m),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1 || value > 1024) {
|
||||
throw std::invalid_argument("ngram size M must be between 1 and 1024 inclusive");
|
||||
}
|
||||
params.speculative.ngram_size_m = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-check-rate"}, "N",
|
||||
string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1) {
|
||||
throw std::invalid_argument("ngram check rate must be at least 1");
|
||||
}
|
||||
params.speculative.ngram_check_rate = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-min-hits"}, "N",
|
||||
string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1) {
|
||||
throw std::invalid_argument("ngram min hits must be at least 1");
|
||||
}
|
||||
params.speculative.ngram_min_hits = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
|
||||
string_format(
|
||||
@@ -3622,8 +3683,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
|
||||
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
|
||||
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
@@ -3638,8 +3699,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
|
||||
params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
|
||||
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
|
||||
172
common/chat.cpp
172
common/chat.cpp
@@ -771,10 +771,12 @@ static std::string apply(
|
||||
|
||||
nlohmann::ordered_json inp = nlohmann::ordered_json{
|
||||
{"messages", messages_override.has_value() ? *messages_override : inputs.messages},
|
||||
{"tools", tools_override.has_value() ? *tools_override : inputs.tools},
|
||||
{"bos_token", tmpl.bos_token()},
|
||||
{"eos_token", tmpl.eos_token()},
|
||||
};
|
||||
if (tools_override.has_value() || !inputs.tools.empty()) {
|
||||
inp["tools"] = tools_override.has_value() ? *tools_override : inputs.tools;
|
||||
}
|
||||
if (inputs.extra_context.is_object()) {
|
||||
// TODO: do we need to merge, or replacing is fine?
|
||||
for (const auto & [k, v] : inputs.extra_context.items()) {
|
||||
@@ -790,9 +792,6 @@ static std::string apply(
|
||||
if (inputs.add_generation_prompt) {
|
||||
inp["add_generation_prompt"] = true;
|
||||
}
|
||||
if (inp["tools"].is_null()) {
|
||||
inp["tools"] = json::array();
|
||||
}
|
||||
|
||||
jinja::global_from_json(ctx, inp, inputs.mark_input);
|
||||
|
||||
@@ -2219,12 +2218,11 @@ static common_chat_params common_chat_params_init_glm_4_5(const common_chat_temp
|
||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
common_chat_params data;
|
||||
const std::optional<json> tools_override = json();
|
||||
const std::optional<json> additional_context = json {
|
||||
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
};
|
||||
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, tools_override, additional_context);
|
||||
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override =*/ std::nullopt, additional_context);
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
@@ -2573,20 +2571,165 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
|
||||
static common_chat_params common_chat_params_init_solar_open(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// TODO: Reasoning effort
|
||||
json additional_context = {};
|
||||
// Copy `reasoning_content` to `reasoning`
|
||||
auto adjusted_messages = json::array();
|
||||
for (const auto & msg : inputs.messages) {
|
||||
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
|
||||
auto adjusted_message = msg;
|
||||
adjusted_message["reasoning"] = msg.at("reasoning_content");
|
||||
adjusted_message.erase("reasoning_content");
|
||||
adjusted_messages.push_back(adjusted_message);
|
||||
} else {
|
||||
adjusted_messages.push_back(msg);
|
||||
}
|
||||
}
|
||||
|
||||
data.prompt = apply(tmpl, inputs, std::nullopt, std::nullopt, additional_context);
|
||||
data.format = COMMON_CHAT_FORMAT_SOLAR_OPEN;
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto include_grammar = true;
|
||||
|
||||
auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
|
||||
// Check if we need to replace the flush token with end token during inference and without generation prompt.
|
||||
if (inputs.is_inference && !inputs.add_generation_prompt) {
|
||||
static constexpr std::string_view return_token = "<|flush|>";
|
||||
static constexpr std::string_view end_token = "<|end|>";
|
||||
if (size_t pos = prompt.rfind(return_token); pos != std::string::npos) {
|
||||
prompt.replace(pos, return_token.length(), end_token);
|
||||
}
|
||||
}
|
||||
|
||||
data.prompt = prompt;
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
"<|think|>",
|
||||
"<|content|>",
|
||||
"<|begin|>",
|
||||
"<|end|>",
|
||||
"<|tool_calls|>",
|
||||
"<|tool_call:begin|>",
|
||||
"<|tool_call:end|>",
|
||||
"<|tool_call:name|>",
|
||||
"<|tool_call:args|>",
|
||||
};
|
||||
|
||||
// TODO: Tool calling
|
||||
auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
|
||||
auto lit_think = p.atomic(p.literal("<|think|>"));
|
||||
auto lit_assistant_begin = p.atomic(p.literal("<|begin|>assistant"));
|
||||
auto lit_content = p.atomic(p.literal("<|content|>"));
|
||||
auto lit_end = p.atomic(p.literal("<|end|>"));
|
||||
auto parser_until_end = p.until("<|end|>");
|
||||
|
||||
// reasoning <- "<|think|>" (!"<|end|>" .)*
|
||||
auto parser_reasoning = p.rule("reasoning", lit_think + p.reasoning(parser_until_end));
|
||||
|
||||
// content <- "<|content|>" (!"<|end|>" .)*
|
||||
auto parser_content = p.rule("content", lit_content + p.content(parser_until_end));
|
||||
|
||||
// wrap_choice(items) <- item-choice wrapped*
|
||||
// item-choice <- items[0] / ... / items[n]
|
||||
// wrapped <- "<|end|><|begin|>assistant" item-choice
|
||||
auto wrap_choice = [&](const std::vector<common_peg_parser> & items) {
|
||||
auto choice = p.choice(items);
|
||||
return choice + p.zero_or_more(lit_end + lit_assistant_begin + choice);
|
||||
};
|
||||
|
||||
// wrap_seq(items) <- item[0] "<|end|><|begin|>assistant" item[1] ...
|
||||
auto wrap_seq = [&](const std::vector<common_peg_parser> & items) {
|
||||
auto seq = p.sequence();
|
||||
for (auto i = 0u; i < items.size(); i++) {
|
||||
if (i == 0) {
|
||||
seq += items[i];
|
||||
continue;
|
||||
}
|
||||
seq += lit_end + lit_assistant_begin + items[i];
|
||||
}
|
||||
return seq;
|
||||
};
|
||||
|
||||
// Response format parser
|
||||
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
|
||||
auto parser_response_format = lit_content + p.content(p.schema(p.json(), "response-format", inputs.json_schema));
|
||||
return p.choice({
|
||||
wrap_seq({parser_reasoning, parser_response_format}),
|
||||
wrap_seq({parser_response_format})
|
||||
});
|
||||
}
|
||||
|
||||
auto lit_tool_call_begin = p.literal("<|tool_call:begin|>");
|
||||
auto lit_tool_call_name = p.literal("<|tool_call:name|>");
|
||||
auto lit_tool_call_args = p.literal("<|tool_call:args|>");
|
||||
auto lit_tool_call_end = p.literal("<|tool_call:end|>");
|
||||
|
||||
// Tool call parser
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
auto parser_tool_call = p.choice();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
const auto & schema = function.at("parameters");
|
||||
|
||||
// tool(name, schema) <- name "<|tool_call:args|>" schema
|
||||
parser_tool_call |= p.rule("tool-" + name,
|
||||
p.atomic(p.tool_name(p.literal(name)) + lit_tool_call_args)
|
||||
+ p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)));
|
||||
});
|
||||
|
||||
auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
|
||||
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
|
||||
|
||||
// tool-calls <- "<|tool_calls|>" tool-call+
|
||||
// tool-call <- "<|tool_call:begin|> call-id "<|tool_call:name|>" &([^<]+ "<|tool_call:args|>") tool-choice "<|tool_call:end|>"
|
||||
// call-id <- [a-zA-Z0-9_-]+
|
||||
// tool-choice <- tool(t[0].name, t[0].schema) / ... / tool(t[n].name, t[n].schema)
|
||||
auto parser_tool_calls = p.trigger_rule("tool-calls",
|
||||
p.atomic(p.literal("<|tool_calls|>"))
|
||||
+ p.repeat(
|
||||
p.tool_open(
|
||||
lit_tool_call_begin
|
||||
+ p.tool_id(p.chars("[a-zA-Z0-9_-]", 1, -1))
|
||||
+ lit_tool_call_name
|
||||
+ p.peek(p.chars("[^<]", 1, -1) + lit_tool_call_args))
|
||||
+ parser_tool_call
|
||||
+ p.tool_close(lit_tool_call_end),
|
||||
/* min = */ 1,
|
||||
/* max = */ max_calls));
|
||||
|
||||
if (min_calls == 1) {
|
||||
// If required, then try any combination of the reasoning, content, and tool call
|
||||
return p.choice({
|
||||
wrap_seq({parser_reasoning, parser_content, parser_tool_calls}),
|
||||
wrap_seq({parser_reasoning, parser_tool_calls}),
|
||||
wrap_seq({parser_content, parser_tool_calls}),
|
||||
wrap_seq({parser_tool_calls})
|
||||
});
|
||||
}
|
||||
|
||||
return wrap_choice({parser_reasoning, parser_content, parser_tool_calls});
|
||||
}
|
||||
|
||||
// Content only parser
|
||||
include_grammar = false;
|
||||
return wrap_choice({parser_reasoning, parser_content});
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls|>"}
|
||||
};
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
@@ -3043,6 +3186,13 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_apriel_1_5(tmpl, params);
|
||||
}
|
||||
|
||||
// Solar Open
|
||||
if (src.find("<|tool_response:begin|>") != std::string::npos &&
|
||||
src.find("<|tool_response:name|>") != std::string::npos &&
|
||||
src.find("<|tool_response:result|>") != std::string::npos) {
|
||||
return common_chat_params_init_solar_open(tmpl, params);
|
||||
}
|
||||
|
||||
// Use generic handler when mixing tools + JSON schema.
|
||||
// TODO: support that mix in handlers below.
|
||||
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
||||
|
||||
@@ -1097,7 +1097,10 @@ common_init_result::common_init_result(common_params & params) :
|
||||
if (params.fit_params) {
|
||||
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
|
||||
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
|
||||
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
|
||||
params.tensor_split,
|
||||
params.tensor_buft_overrides.data(),
|
||||
params.fit_params_target.data(),
|
||||
params.fit_params_min_ctx,
|
||||
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
}
|
||||
|
||||
@@ -1208,10 +1211,6 @@ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
|
||||
return pimpl->lora;
|
||||
}
|
||||
|
||||
void common_init_result::free_context() {
|
||||
pimpl->context.reset();
|
||||
}
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
common_init_result_ptr res(new common_init_result(params));
|
||||
|
||||
|
||||
@@ -164,6 +164,17 @@ enum common_params_sampling_config : uint64_t {
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
|
||||
};
|
||||
|
||||
enum common_speculative_type {
|
||||
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache
|
||||
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
|
||||
};
|
||||
|
||||
// sampling parameters
|
||||
struct common_params_sampling {
|
||||
@@ -242,17 +253,40 @@ struct common_params_model {
|
||||
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
|
||||
};
|
||||
|
||||
struct common_params_speculative {
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
struct common_ngram_mod;
|
||||
|
||||
int32_t n_ctx = 0; // draft context size
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
struct common_params_speculative {
|
||||
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding
|
||||
|
||||
// general-purpose speculative decoding parameters
|
||||
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
|
||||
// ngram-based speculative decoding
|
||||
|
||||
uint16_t ngram_size_n = 12; // ngram size for lookup
|
||||
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
|
||||
uint16_t ngram_check_rate = 1; // check rate for ngram lookup
|
||||
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
|
||||
|
||||
std::shared_ptr<common_ngram_mod> ngram_mod;
|
||||
|
||||
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
|
||||
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||
|
||||
// draft-model speculative decoding
|
||||
|
||||
struct common_params_model mparams_dft;
|
||||
|
||||
llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts
|
||||
|
||||
llama_context_params cparams_dft; // these are the parameters for the draft llama_context
|
||||
|
||||
int32_t n_ctx = 0; // draft context size
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
|
||||
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
||||
@@ -260,7 +294,14 @@ struct common_params_speculative {
|
||||
struct cpu_params cpuparams;
|
||||
struct cpu_params cpuparams_batch;
|
||||
|
||||
struct common_params_model model;
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
|
||||
bool has_dft() const {
|
||||
return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
|
||||
}
|
||||
};
|
||||
|
||||
struct common_params_vocoder {
|
||||
@@ -378,8 +419,6 @@ struct common_params {
|
||||
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
|
||||
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
|
||||
std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
|
||||
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
|
||||
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||
std::string logits_file = ""; // file for saving *all* logits // NOLINT
|
||||
|
||||
// llama-debug specific options
|
||||
@@ -438,7 +477,7 @@ struct common_params {
|
||||
|
||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||
bool use_mmap = true; // enable mmap to use filesystem cache
|
||||
bool use_direct_io = true; // read from disk without buffering for faster model loading
|
||||
bool use_direct_io = false; // read from disk without buffering
|
||||
bool use_mlock = false; // use mlock to keep model in memory
|
||||
bool verbose_prompt = false; // print prompt tokens before generation
|
||||
bool display_prompt = true; // print prompt before generation
|
||||
@@ -575,10 +614,6 @@ struct common_params {
|
||||
// return false from callback to abort model loading or true to continue
|
||||
llama_progress_callback load_progress_callback = NULL;
|
||||
void * load_progress_callback_user_data = NULL;
|
||||
|
||||
bool has_speculative() const {
|
||||
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
|
||||
}
|
||||
};
|
||||
|
||||
// call once at the start of a program if it uses libcommon
|
||||
@@ -714,8 +749,6 @@ struct common_init_result {
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> & lora();
|
||||
|
||||
void free_context();
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
|
||||
@@ -44,6 +44,12 @@ static std::string get_line_col(const std::string & source, size_t pos) {
|
||||
return "line " + std::to_string(line) + ", column " + std::to_string(col);
|
||||
}
|
||||
|
||||
static void ensure_key_type_allowed(const value & val) {
|
||||
if (!val->is_hashable()) {
|
||||
throw std::runtime_error("Type: " + val->type() + " is not allowed as object key");
|
||||
}
|
||||
}
|
||||
|
||||
// execute with error handling
|
||||
value statement::execute(context & ctx) {
|
||||
try {
|
||||
@@ -95,20 +101,10 @@ value identifier::execute_impl(context & ctx) {
|
||||
value object_literal::execute_impl(context & ctx) {
|
||||
auto obj = mk_val<value_object>();
|
||||
for (const auto & pair : val) {
|
||||
value key_val = pair.first->execute(ctx);
|
||||
if (!is_val<value_string>(key_val) && !is_val<value_int>(key_val)) {
|
||||
throw std::runtime_error("Object literal: keys must be string or int values, got " + key_val->type());
|
||||
}
|
||||
std::string key = key_val->as_string().str();
|
||||
value key = pair.first->execute(ctx);
|
||||
value val = pair.second->execute(ctx);
|
||||
JJ_DEBUG("Object literal: setting key '%s' with value type %s", key.c_str(), val->type().c_str());
|
||||
JJ_DEBUG("Object literal: setting key '%s' with value type %s", key->as_string().str().c_str(), val->type().c_str());
|
||||
obj->insert(key, val);
|
||||
|
||||
if (is_val<value_int>(key_val)) {
|
||||
obj->val_obj.is_key_numeric = true;
|
||||
} else if (obj->val_obj.is_key_numeric) {
|
||||
throw std::runtime_error("Object literal: cannot mix numeric and non-numeric keys");
|
||||
}
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
@@ -127,9 +123,9 @@ value binary_expression::execute_impl(context & ctx) {
|
||||
value right_val = right->execute(ctx);
|
||||
JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right_val->type().c_str());
|
||||
if (op.value == "==") {
|
||||
return mk_val<value_bool>(value_compare(left_val, right_val, value_compare_op::eq));
|
||||
return mk_val<value_bool>(*left_val == *right_val);
|
||||
} else if (op.value == "!=") {
|
||||
return mk_val<value_bool>(!value_compare(left_val, right_val, value_compare_op::eq));
|
||||
return mk_val<value_bool>(!(*left_val == *right_val));
|
||||
}
|
||||
|
||||
auto workaround_concat_null_with_str = [&](value & res) -> bool {
|
||||
@@ -230,7 +226,7 @@ value binary_expression::execute_impl(context & ctx) {
|
||||
auto & arr = right_val->as_array();
|
||||
bool member = false;
|
||||
for (const auto & item : arr) {
|
||||
if (value_compare(left_val, item, value_compare_op::eq)) {
|
||||
if (*left_val == *item) {
|
||||
member = true;
|
||||
break;
|
||||
}
|
||||
@@ -265,10 +261,9 @@ value binary_expression::execute_impl(context & ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
// String in object
|
||||
if (is_val<value_string>(left_val) && is_val<value_object>(right_val)) {
|
||||
auto key = left_val->as_string().str();
|
||||
bool has_key = right_val->has_key(key);
|
||||
// Value key in object
|
||||
if (is_val<value_object>(right_val)) {
|
||||
bool has_key = right_val->has_key(left_val);
|
||||
if (op.value == "in") {
|
||||
return mk_val<value_bool>(has_key);
|
||||
} else if (op.value == "not in") {
|
||||
@@ -465,14 +460,8 @@ value for_statement::execute_impl(context & ctx) {
|
||||
JJ_DEBUG("%s", "For loop over object keys");
|
||||
auto & obj = iterable_val->as_ordered_object();
|
||||
for (auto & p : obj) {
|
||||
auto tuple = mk_val<value_array>();
|
||||
if (iterable_val->val_obj.is_key_numeric) {
|
||||
tuple->push_back(mk_val<value_int>(std::stoll(p.first)));
|
||||
} else {
|
||||
tuple->push_back(mk_val<value_string>(p.first));
|
||||
}
|
||||
tuple->push_back(p.second);
|
||||
items.push_back(tuple);
|
||||
auto tuple = mk_val<value_tuple>(p);
|
||||
items.push_back(std::move(tuple));
|
||||
}
|
||||
if (ctx.is_get_stats) {
|
||||
iterable_val->stats.used = true;
|
||||
@@ -602,11 +591,13 @@ value set_statement::execute_impl(context & ctx) {
|
||||
auto rhs = val ? val->execute(ctx) : exec_statements(body, ctx);
|
||||
|
||||
if (is_stmt<identifier>(assignee)) {
|
||||
// case: {% set my_var = value %}
|
||||
auto var_name = cast_stmt<identifier>(assignee)->val;
|
||||
JJ_DEBUG("Setting global variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str());
|
||||
ctx.set_val(var_name, rhs);
|
||||
|
||||
} else if (is_stmt<tuple_literal>(assignee)) {
|
||||
// case: {% set a, b = value %}
|
||||
auto tuple = cast_stmt<tuple_literal>(assignee);
|
||||
if (!is_val<value_array>(rhs)) {
|
||||
throw std::runtime_error("Cannot unpack non-iterable type in set: " + rhs->type());
|
||||
@@ -625,6 +616,7 @@ value set_statement::execute_impl(context & ctx) {
|
||||
}
|
||||
|
||||
} else if (is_stmt<member_expression>(assignee)) {
|
||||
// case: {% set ns.my_var = value %}
|
||||
auto member = cast_stmt<member_expression>(assignee);
|
||||
if (member->computed) {
|
||||
throw std::runtime_error("Cannot assign to computed member");
|
||||
@@ -767,22 +759,22 @@ value member_expression::execute_impl(context & ctx) {
|
||||
}
|
||||
|
||||
JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str());
|
||||
ensure_key_type_allowed(property);
|
||||
|
||||
value val = mk_val<value_undefined>("object_property");
|
||||
|
||||
if (is_val<value_undefined>(object)) {
|
||||
JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined");
|
||||
return val;
|
||||
|
||||
} else if (is_val<value_object>(object)) {
|
||||
if (!is_val<value_string>(property)) {
|
||||
throw std::runtime_error("Cannot access object with non-string: got " + property->type());
|
||||
}
|
||||
auto key = property->as_string().str();
|
||||
val = object->at(key, val);
|
||||
val = object->at(property, val);
|
||||
if (is_val<value_undefined>(val)) {
|
||||
val = try_builtin_func(ctx, key, object, true);
|
||||
}
|
||||
JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str());
|
||||
|
||||
} else if (is_val<value_array>(object) || is_val<value_string>(object)) {
|
||||
if (is_val<value_int>(property)) {
|
||||
int64_t index = property->as_int();
|
||||
@@ -806,6 +798,7 @@ value member_expression::execute_impl(context & ctx) {
|
||||
auto key = property->as_string().str();
|
||||
JJ_DEBUG("Accessing %s built-in '%s'", is_val<value_array>(object) ? "array" : "string", key.c_str());
|
||||
val = try_builtin_func(ctx, key, object, true);
|
||||
|
||||
} else {
|
||||
throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type());
|
||||
}
|
||||
|
||||
@@ -79,18 +79,18 @@ struct context {
|
||||
}
|
||||
|
||||
value get_val(const std::string & name) {
|
||||
auto it = env->val_obj.unordered.find(name);
|
||||
if (it != env->val_obj.unordered.end()) {
|
||||
return it->second;
|
||||
} else {
|
||||
return mk_val<value_undefined>(name);
|
||||
}
|
||||
value default_val = mk_val<value_undefined>(name);
|
||||
return env->at(name, default_val);
|
||||
}
|
||||
|
||||
void set_val(const std::string & name, const value & val) {
|
||||
env->insert(name, val);
|
||||
}
|
||||
|
||||
void set_val(const value & name, const value & val) {
|
||||
env->insert(name, val);
|
||||
}
|
||||
|
||||
void print_vars() const {
|
||||
printf("Context Variables:\n%s\n", value_to_json(env, 2).c_str());
|
||||
}
|
||||
@@ -344,9 +344,19 @@ struct array_literal : public expression {
|
||||
}
|
||||
};
|
||||
|
||||
struct tuple_literal : public array_literal {
|
||||
explicit tuple_literal(statements && val) : array_literal(std::move(val)) {}
|
||||
struct tuple_literal : public expression {
|
||||
statements val;
|
||||
explicit tuple_literal(statements && val) : val(std::move(val)) {
|
||||
for (const auto& item : this->val) chk_type<expression>(item);
|
||||
}
|
||||
std::string type() const override { return "TupleLiteral"; }
|
||||
value execute_impl(context & ctx) override {
|
||||
auto arr = mk_val<value_array>();
|
||||
for (const auto & item_stmt : val) {
|
||||
arr->push_back(item_stmt->execute(ctx));
|
||||
}
|
||||
return mk_val<value_tuple>(std::move(arr->as_array()));
|
||||
}
|
||||
};
|
||||
|
||||
struct object_literal : public expression {
|
||||
|
||||
@@ -61,6 +61,12 @@ size_t string::length() const {
|
||||
return len;
|
||||
}
|
||||
|
||||
void string::hash_update(hasher & hash) const noexcept {
|
||||
for (const auto & part : parts) {
|
||||
hash.update(part.val.data(), part.val.length());
|
||||
}
|
||||
}
|
||||
|
||||
bool string::all_parts_are_input() const {
|
||||
for (const auto & part : parts) {
|
||||
if (!part.is_input) {
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace jinja {
|
||||
|
||||
// allow differentiate between user input strings and template strings
|
||||
@@ -37,6 +39,7 @@ struct string {
|
||||
|
||||
std::string str() const;
|
||||
size_t length() const;
|
||||
void hash_update(hasher & hash) const noexcept;
|
||||
bool all_parts_are_input() const;
|
||||
bool is_uppercase() const;
|
||||
bool is_lowercase() const;
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
|
||||
namespace jinja {
|
||||
|
||||
@@ -46,4 +48,102 @@ static std::string fmt_error_with_source(const std::string & tag, const std::str
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Note: this is a simple hasher, not cryptographically secure, just for hash table usage
|
||||
struct hasher {
|
||||
static constexpr auto size_t_digits = sizeof(size_t) * 8;
|
||||
static constexpr size_t prime = size_t_digits == 64 ? 0x100000001b3 : 0x01000193;
|
||||
static constexpr size_t seed = size_t_digits == 64 ? 0xcbf29ce484222325 : 0x811c9dc5;
|
||||
static constexpr auto block_size = sizeof(size_t); // in bytes; allowing the compiler to vectorize the computation
|
||||
|
||||
static_assert(size_t_digits == 64 || size_t_digits == 32);
|
||||
static_assert(block_size == 8 || block_size == 4);
|
||||
|
||||
uint8_t buffer[block_size];
|
||||
size_t idx = 0; // current index in buffer
|
||||
size_t state = seed;
|
||||
|
||||
hasher() = default;
|
||||
hasher(const std::type_info & type_inf) noexcept {
|
||||
const auto type_hash = type_inf.hash_code();
|
||||
update(&type_hash, sizeof(type_hash));
|
||||
}
|
||||
|
||||
// Properties:
|
||||
// - update is not associative: update(a).update(b) != update(b).update(a)
|
||||
// - update(a ~ b) == update(a).update(b) with ~ as concatenation operator --> useful for streaming
|
||||
// - update("", 0) --> state unchanged with empty input
|
||||
hasher& update(void const * bytes, size_t len) noexcept {
|
||||
const uint8_t * c = static_cast<uint8_t const *>(bytes);
|
||||
if (len == 0) {
|
||||
return *this;
|
||||
}
|
||||
size_t processed = 0;
|
||||
|
||||
// first, fill the existing buffer if it's partial
|
||||
if (idx > 0) {
|
||||
size_t to_fill = block_size - idx;
|
||||
if (to_fill > len) {
|
||||
to_fill = len;
|
||||
}
|
||||
std::memcpy(buffer + idx, c, to_fill);
|
||||
idx += to_fill;
|
||||
processed += to_fill;
|
||||
if (idx == block_size) {
|
||||
update_block(buffer);
|
||||
idx = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// process full blocks from the remaining input
|
||||
for (; processed + block_size <= len; processed += block_size) {
|
||||
update_block(c + processed);
|
||||
}
|
||||
|
||||
// buffer any remaining bytes
|
||||
size_t remaining = len - processed;
|
||||
if (remaining > 0) {
|
||||
std::memcpy(buffer, c + processed, remaining);
|
||||
idx = remaining;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// convenience function for testing only
|
||||
hasher& update(const std::string & s) noexcept {
|
||||
return update(s.data(), s.size());
|
||||
}
|
||||
|
||||
// finalize and get the hash value
|
||||
// note: after calling digest, the hasher state is modified, do not call update() again
|
||||
size_t digest() noexcept {
|
||||
// if there are remaining bytes in buffer, fill the rest with zeros and process
|
||||
if (idx > 0) {
|
||||
for (size_t i = idx; i < block_size; ++i) {
|
||||
buffer[i] = 0;
|
||||
}
|
||||
update_block(buffer);
|
||||
idx = 0;
|
||||
}
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
private:
|
||||
// IMPORTANT: block must have at least block_size bytes
|
||||
void update_block(const uint8_t * block) noexcept {
|
||||
size_t blk = static_cast<uint32_t>(block[0])
|
||||
| (static_cast<uint32_t>(block[1]) << 8)
|
||||
| (static_cast<uint32_t>(block[2]) << 16)
|
||||
| (static_cast<uint32_t>(block[3]) << 24);
|
||||
if constexpr (block_size == 8) {
|
||||
blk = blk | (static_cast<uint64_t>(block[4]) << 32)
|
||||
| (static_cast<uint64_t>(block[5]) << 40)
|
||||
| (static_cast<uint64_t>(block[6]) << 48)
|
||||
| (static_cast<uint64_t>(block[7]) << 56);
|
||||
}
|
||||
state ^= blk;
|
||||
state *= prime;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -114,6 +114,18 @@ static T slice(const T & array, int64_t start, int64_t stop, int64_t step = 1) {
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static value empty_value_fn(const func_args &) {
|
||||
if constexpr (std::is_same_v<T, value_int>) {
|
||||
return mk_val<T>(0);
|
||||
} else if constexpr (std::is_same_v<T, value_float>) {
|
||||
return mk_val<T>(0.0);
|
||||
} else if constexpr (std::is_same_v<T, value_bool>) {
|
||||
return mk_val<T>(false);
|
||||
} else {
|
||||
return mk_val<T>();
|
||||
}
|
||||
}
|
||||
template<typename T>
|
||||
static value test_type_fn(const func_args & args) {
|
||||
args.ensure_count(1);
|
||||
@@ -128,6 +140,13 @@ static value test_type_fn(const func_args & args) {
|
||||
JJ_DEBUG("test_type_fn: type=%s or %s result=%d", typeid(T).name(), typeid(U).name(), is_type ? 1 : 0);
|
||||
return mk_val<value_bool>(is_type);
|
||||
}
|
||||
template<typename T, typename U, typename V>
|
||||
static value test_type_fn(const func_args & args) {
|
||||
args.ensure_count(1);
|
||||
bool is_type = is_val<T>(args.get_pos(0)) || is_val<U>(args.get_pos(0)) || is_val<V>(args.get_pos(0));
|
||||
JJ_DEBUG("test_type_fn: type=%s, %s or %s result=%d", typeid(T).name(), typeid(U).name(), typeid(V).name(), is_type ? 1 : 0);
|
||||
return mk_val<value_bool>(is_type);
|
||||
}
|
||||
template<value_compare_op op>
|
||||
static value test_compare_fn(const func_args & args) {
|
||||
args.ensure_count(2, 2);
|
||||
@@ -163,7 +182,7 @@ static value selectattr(const func_args & args) {
|
||||
args.ensure_vals<value_array, value_string, value_string, value_string>(true, true, false, false);
|
||||
|
||||
auto arr = args.get_pos(0)->as_array();
|
||||
auto attr_name = args.get_pos(1)->as_string().str();
|
||||
auto attribute = args.get_pos(1);
|
||||
auto out = mk_val<value_array>();
|
||||
value val_default = mk_val<value_undefined>();
|
||||
|
||||
@@ -173,7 +192,7 @@ static value selectattr(const func_args & args) {
|
||||
if (!is_val<value_object>(item)) {
|
||||
throw raised_exception("selectattr: item is not an object");
|
||||
}
|
||||
value attr_val = item->at(attr_name, val_default);
|
||||
value attr_val = item->at(attribute, val_default);
|
||||
bool is_selected = attr_val->as_bool();
|
||||
if constexpr (is_reject) is_selected = !is_selected;
|
||||
if (is_selected) out->push_back(item);
|
||||
@@ -217,7 +236,7 @@ static value selectattr(const func_args & args) {
|
||||
if (!is_val<value_object>(item)) {
|
||||
throw raised_exception("selectattr: item is not an object");
|
||||
}
|
||||
value attr_val = item->at(attr_name, val_default);
|
||||
value attr_val = item->at(attribute, val_default);
|
||||
func_args test_args(args.ctx);
|
||||
test_args.push_back(attr_val); // attribute value
|
||||
test_args.push_back(extra_arg); // extra argument
|
||||
@@ -347,8 +366,8 @@ const func_builtins & global_builtins() {
|
||||
{"test_is_integer", test_type_fn<value_int>},
|
||||
{"test_is_float", test_type_fn<value_float>},
|
||||
{"test_is_number", test_type_fn<value_int, value_float>},
|
||||
{"test_is_iterable", test_type_fn<value_array, value_string>},
|
||||
{"test_is_sequence", test_type_fn<value_array, value_string>},
|
||||
{"test_is_iterable", test_type_fn<value_array, value_string, value_undefined>},
|
||||
{"test_is_sequence", test_type_fn<value_array, value_string, value_undefined>},
|
||||
{"test_is_mapping", test_type_fn<value_object>},
|
||||
{"test_is_lower", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_string>();
|
||||
@@ -741,6 +760,7 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
args.ensure_count(1, 4);
|
||||
args.ensure_vals<value_array, value_int, value_int, value_int>(true, true, false, false);
|
||||
|
||||
auto val = args.get_pos(0);
|
||||
auto arg0 = args.get_pos(1);
|
||||
auto arg1 = args.get_pos(2, mk_val<value_undefined>());
|
||||
auto arg2 = args.get_pos(3, mk_val<value_undefined>());
|
||||
@@ -762,10 +782,8 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
if (step == 0) {
|
||||
throw raised_exception("slice step cannot be zero");
|
||||
}
|
||||
auto arr = slice(args.get_pos(0)->as_array(), start, stop, step);
|
||||
auto res = mk_val<value_array>();
|
||||
res->val_arr = std::move(arr);
|
||||
return res;
|
||||
auto arr = slice(val->as_array(), start, stop, step);
|
||||
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
|
||||
}},
|
||||
{"selectattr", selectattr<false>},
|
||||
{"select", selectattr<false>},
|
||||
@@ -785,15 +803,14 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
}
|
||||
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
|
||||
const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str();
|
||||
const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
|
||||
std::string result;
|
||||
for (size_t i = 0; i < arr.size(); ++i) {
|
||||
value val_arr = arr[i];
|
||||
if (!attribute->is_undefined()) {
|
||||
if (attr_is_int && is_val<value_array>(val_arr)) {
|
||||
val_arr = val_arr->at(attr_int);
|
||||
} else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(val_arr)) {
|
||||
val_arr = val_arr->at(attr_name);
|
||||
} else if (!attr_is_int && is_val<value_object>(val_arr)) {
|
||||
val_arr = val_arr->at(attribute);
|
||||
}
|
||||
}
|
||||
if (!is_val<value_string>(val_arr) && !is_val<value_int>(val_arr) && !is_val<value_float>(val_arr)) {
|
||||
@@ -808,9 +825,7 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
}},
|
||||
{"string", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_array>();
|
||||
auto str = mk_val<value_string>();
|
||||
gather_string_parts_recursive(args.get_pos(0), str);
|
||||
return str;
|
||||
return mk_val<value_string>(args.get_pos(0)->as_string());
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
{"map", [](const func_args & args) -> value {
|
||||
@@ -821,26 +836,26 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
if (!is_val<value_kwarg>(args.get_args().at(1))) {
|
||||
throw not_implemented_exception("map: filter-mapping not implemented");
|
||||
}
|
||||
value val = args.get_pos(0);
|
||||
value attribute = args.get_kwarg_or_pos("attribute", 1);
|
||||
const bool attr_is_int = is_val<value_int>(attribute);
|
||||
if (!is_val<value_string>(attribute) && !attr_is_int) {
|
||||
throw raised_exception("map: attribute must be string or integer");
|
||||
}
|
||||
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
|
||||
const std::string attr_name = attribute->as_string().str();
|
||||
value default_val = args.get_kwarg("default", mk_val<value_undefined>());
|
||||
auto out = mk_val<value_array>();
|
||||
auto arr = args.get_pos(0)->as_array();
|
||||
auto arr = val->as_array();
|
||||
for (const auto & item : arr) {
|
||||
value attr_val;
|
||||
if (attr_is_int) {
|
||||
attr_val = is_val<value_array>(item) ? item->at(attr_int, default_val) : default_val;
|
||||
} else {
|
||||
attr_val = is_val<value_object>(item) ? item->at(attr_name, default_val) : default_val;
|
||||
attr_val = is_val<value_object>(item) ? item->at(attribute, default_val) : default_val;
|
||||
}
|
||||
out->push_back(attr_val);
|
||||
}
|
||||
return out;
|
||||
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(out->as_array())) : out;
|
||||
}},
|
||||
{"append", [](const func_args & args) -> value {
|
||||
args.ensure_count(2);
|
||||
@@ -867,6 +882,7 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
if (!is_val<value_array>(args.get_pos(0))) {
|
||||
throw raised_exception("sort: first argument must be an array");
|
||||
}
|
||||
value val = args.get_pos(0);
|
||||
value val_reverse = args.get_kwarg_or_pos("reverse", 1);
|
||||
value val_case = args.get_kwarg_or_pos("case_sensitive", 2);
|
||||
value attribute = args.get_kwarg_or_pos("attribute", 3);
|
||||
@@ -875,8 +891,7 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
const bool reverse = val_reverse->as_bool(); // undefined == false
|
||||
const bool attr_is_int = is_val<value_int>(attribute);
|
||||
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
|
||||
const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
|
||||
std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
|
||||
std::vector<value> arr = val->as_array(); // copy
|
||||
std::sort(arr.begin(), arr.end(),[&](const value & a, const value & b) {
|
||||
value val_a = a;
|
||||
value val_b = b;
|
||||
@@ -884,22 +899,23 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
if (attr_is_int && is_val<value_array>(a) && is_val<value_array>(b)) {
|
||||
val_a = a->at(attr_int);
|
||||
val_b = b->at(attr_int);
|
||||
} else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(a) && is_val<value_object>(b)) {
|
||||
val_a = a->at(attr_name);
|
||||
val_b = b->at(attr_name);
|
||||
} else if (!attr_is_int && is_val<value_object>(a) && is_val<value_object>(b)) {
|
||||
val_a = a->at(attribute);
|
||||
val_b = b->at(attribute);
|
||||
} else {
|
||||
throw raised_exception("sort: unsupported object attribute comparison");
|
||||
throw raised_exception("sort: unsupported object attribute comparison between " + a->type() + " and " + b->type());
|
||||
}
|
||||
}
|
||||
return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt);
|
||||
});
|
||||
return mk_val<value_array>(arr);
|
||||
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
|
||||
}},
|
||||
{"reverse", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_array>();
|
||||
std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
|
||||
value val = args.get_pos(0);
|
||||
std::vector<value> arr = val->as_array(); // copy
|
||||
std::reverse(arr.begin(), arr.end());
|
||||
return mk_val<value_array>(arr);
|
||||
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
|
||||
}},
|
||||
{"unique", [](const func_args &) -> value {
|
||||
throw not_implemented_exception("Array unique builtin not implemented");
|
||||
@@ -930,7 +946,7 @@ const func_builtins & value_object_t::get_builtins() const {
|
||||
default_val = args.get_pos(2);
|
||||
}
|
||||
const value obj = args.get_pos(0);
|
||||
std::string key = args.get_pos(1)->as_string().str();
|
||||
const value key = args.get_pos(1);
|
||||
return obj->at(key, default_val);
|
||||
}},
|
||||
{"keys", [](const func_args & args) -> value {
|
||||
@@ -938,7 +954,7 @@ const func_builtins & value_object_t::get_builtins() const {
|
||||
const auto & obj = args.get_pos(0)->as_ordered_object();
|
||||
auto result = mk_val<value_array>();
|
||||
for (const auto & pair : obj) {
|
||||
result->push_back(mk_val<value_string>(pair.first));
|
||||
result->push_back(pair.first);
|
||||
}
|
||||
return result;
|
||||
}},
|
||||
@@ -956,15 +972,16 @@ const func_builtins & value_object_t::get_builtins() const {
|
||||
const auto & obj = args.get_pos(0)->as_ordered_object();
|
||||
auto result = mk_val<value_array>();
|
||||
for (const auto & pair : obj) {
|
||||
auto item = mk_val<value_array>();
|
||||
item->push_back(mk_val<value_string>(pair.first));
|
||||
item->push_back(pair.second);
|
||||
auto item = mk_val<value_tuple>(pair);
|
||||
result->push_back(std::move(item));
|
||||
}
|
||||
return result;
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
{"string", tojson},
|
||||
{"string", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_object>();
|
||||
return mk_val<value_string>(args.get_pos(0)->as_string());
|
||||
}},
|
||||
{"length", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_object>();
|
||||
const auto & obj = args.get_pos(0)->as_ordered_object();
|
||||
@@ -985,11 +1002,11 @@ const func_builtins & value_object_t::get_builtins() const {
|
||||
const bool reverse = val_reverse->as_bool(); // undefined == false
|
||||
const bool by_value = is_val<value_string>(val_by) && val_by->as_string().str() == "value" ? true : false;
|
||||
auto result = mk_val<value_object>(val_input); // copy
|
||||
std::sort(result->val_obj.ordered.begin(), result->val_obj.ordered.end(), [&](const auto & a, const auto & b) {
|
||||
std::sort(result->val_obj.begin(), result->val_obj.end(), [&](const auto & a, const auto & b) {
|
||||
if (by_value) {
|
||||
return value_compare(a.second, b.second, reverse ? value_compare_op::gt : value_compare_op::lt);
|
||||
} else {
|
||||
return reverse ? a.first > b.first : a.first < b.first;
|
||||
return value_compare(a.first, b.first, reverse ? value_compare_op::gt : value_compare_op::lt);
|
||||
}
|
||||
});
|
||||
return result;
|
||||
@@ -1005,7 +1022,22 @@ const func_builtins & value_none_t::get_builtins() const {
|
||||
static const func_builtins builtins = {
|
||||
{"default", default_value},
|
||||
{"tojson", tojson},
|
||||
{"string", [](const func_args &) -> value { return mk_val<value_string>("None"); }}
|
||||
{"string", [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
}},
|
||||
{"safe", [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
}},
|
||||
{"strip", [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
}},
|
||||
{"items", empty_value_fn<value_array>},
|
||||
{"map", empty_value_fn<value_array>},
|
||||
{"reject", empty_value_fn<value_array>},
|
||||
{"rejectattr", empty_value_fn<value_array>},
|
||||
{"select", empty_value_fn<value_array>},
|
||||
{"selectattr", empty_value_fn<value_array>},
|
||||
{"unique", empty_value_fn<value_array>},
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
@@ -1014,10 +1046,33 @@ const func_builtins & value_none_t::get_builtins() const {
|
||||
const func_builtins & value_undefined_t::get_builtins() const {
|
||||
static const func_builtins builtins = {
|
||||
{"default", default_value},
|
||||
{"tojson", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_undefined>();
|
||||
return mk_val<value_string>("null");
|
||||
}},
|
||||
{"capitalize", empty_value_fn<value_string>},
|
||||
{"first", empty_value_fn<value_undefined>},
|
||||
{"items", empty_value_fn<value_array>},
|
||||
{"join", empty_value_fn<value_string>},
|
||||
{"last", empty_value_fn<value_undefined>},
|
||||
{"length", empty_value_fn<value_int>},
|
||||
{"list", empty_value_fn<value_array>},
|
||||
{"lower", empty_value_fn<value_string>},
|
||||
{"map", empty_value_fn<value_array>},
|
||||
{"max", empty_value_fn<value_undefined>},
|
||||
{"min", empty_value_fn<value_undefined>},
|
||||
{"reject", empty_value_fn<value_array>},
|
||||
{"rejectattr", empty_value_fn<value_array>},
|
||||
{"replace", empty_value_fn<value_string>},
|
||||
{"reverse", empty_value_fn<value_array>},
|
||||
{"safe", empty_value_fn<value_string>},
|
||||
{"select", empty_value_fn<value_array>},
|
||||
{"selectattr", empty_value_fn<value_array>},
|
||||
{"sort", empty_value_fn<value_array>},
|
||||
{"string", empty_value_fn<value_string>},
|
||||
{"strip", empty_value_fn<value_string>},
|
||||
{"sum", empty_value_fn<value_int>},
|
||||
{"title", empty_value_fn<value_string>},
|
||||
{"truncate", empty_value_fn<value_string>},
|
||||
{"unique", empty_value_fn<value_array>},
|
||||
{"upper", empty_value_fn<value_string>},
|
||||
{"wordcount", empty_value_fn<value_int>},
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
@@ -1134,6 +1189,8 @@ void global_from_json(context & ctx, const nlohmann::ordered_json & json_obj, bo
|
||||
}
|
||||
}
|
||||
|
||||
// recursively convert value to JSON string
|
||||
// TODO: avoid circular references
|
||||
static void value_to_json_internal(std::ostringstream & oss, const value & val, int curr_lvl, int indent, const std::string_view item_sep, const std::string_view key_sep) {
|
||||
auto indent_str = [indent, curr_lvl]() -> std::string {
|
||||
return (indent > 0) ? std::string(curr_lvl * indent, ' ') : "";
|
||||
@@ -1196,7 +1253,8 @@ static void value_to_json_internal(std::ostringstream & oss, const value & val,
|
||||
size_t i = 0;
|
||||
for (const auto & pair : obj) {
|
||||
oss << indent_str() << (indent > 0 ? std::string(indent, ' ') : "");
|
||||
oss << "\"" << pair.first << "\"" << key_sep;
|
||||
value_to_json_internal(oss, mk_val<value_string>(pair.first->as_string().str()), curr_lvl + 1, indent, item_sep, key_sep);
|
||||
oss << key_sep;
|
||||
value_to_json_internal(oss, pair.second, curr_lvl + 1, indent, item_sep, key_sep);
|
||||
if (i < obj.size() - 1) {
|
||||
oss << item_sep;
|
||||
@@ -1219,4 +1277,19 @@ std::string value_to_json(const value & val, int indent, const std::string_view
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// TODO: avoid circular references
|
||||
std::string value_to_string_repr(const value & val) {
|
||||
if (is_val<value_string>(val)) {
|
||||
const std::string val_str = val->as_string().str();
|
||||
|
||||
if (val_str.find('\'') != std::string::npos) {
|
||||
return value_to_json(val);
|
||||
} else {
|
||||
return "'" + val_str + "'";
|
||||
}
|
||||
} else {
|
||||
return val->as_repr();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "string.h"
|
||||
#include "utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
@@ -10,6 +12,7 @@
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace jinja {
|
||||
@@ -93,7 +96,8 @@ void global_from_json(context & ctx, const T_JSON & json_obj, bool mark_input);
|
||||
|
||||
struct func_args; // function argument values
|
||||
|
||||
using func_handler = std::function<value(const func_args &)>;
|
||||
using func_hptr = value(const func_args &);
|
||||
using func_handler = std::function<func_hptr>;
|
||||
using func_builtins = std::map<std::string, func_handler>;
|
||||
|
||||
enum value_compare_op { eq, ge, gt, lt, ne };
|
||||
@@ -103,28 +107,9 @@ struct value_t {
|
||||
int64_t val_int;
|
||||
double val_flt;
|
||||
string val_str;
|
||||
bool val_bool;
|
||||
|
||||
std::vector<value> val_arr;
|
||||
|
||||
struct map {
|
||||
// once set to true, all keys must be numeric
|
||||
// caveat: we only allow either all numeric keys or all non-numeric keys
|
||||
// for now, this only applied to for_statement in case of iterating over object keys/items
|
||||
bool is_key_numeric = false;
|
||||
std::map<std::string, value> unordered;
|
||||
std::vector<std::pair<std::string, value>> ordered;
|
||||
void insert(const std::string & key, const value & val) {
|
||||
if (unordered.find(key) != unordered.end()) {
|
||||
// if key exists, remove from ordered list
|
||||
ordered.erase(std::remove_if(ordered.begin(), ordered.end(),
|
||||
[&](const std::pair<std::string, value> & p) { return p.first == key; }),
|
||||
ordered.end());
|
||||
}
|
||||
unordered[key] = val;
|
||||
ordered.push_back({key, val});
|
||||
}
|
||||
} val_obj;
|
||||
std::vector<std::pair<value, value>> val_obj;
|
||||
|
||||
func_handler val_func;
|
||||
|
||||
@@ -139,6 +124,7 @@ struct value_t {
|
||||
value_t(const value_t &) = default;
|
||||
virtual ~value_t() = default;
|
||||
|
||||
// Note: only for debugging and error reporting purposes
|
||||
virtual std::string type() const { return ""; }
|
||||
|
||||
virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); }
|
||||
@@ -146,7 +132,7 @@ struct value_t {
|
||||
virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
|
||||
virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
|
||||
virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
|
||||
virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual const std::vector<std::pair<value, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
|
||||
virtual bool is_none() const { return false; }
|
||||
virtual bool is_undefined() const { return false; }
|
||||
@@ -154,43 +140,66 @@ struct value_t {
|
||||
throw std::runtime_error("No builtins available for type " + type());
|
||||
}
|
||||
|
||||
virtual bool has_key(const std::string & key) {
|
||||
return val_obj.unordered.find(key) != val_obj.unordered.end();
|
||||
}
|
||||
virtual value & at(const std::string & key, value & default_val) {
|
||||
auto it = val_obj.unordered.find(key);
|
||||
if (it == val_obj.unordered.end()) {
|
||||
return default_val;
|
||||
}
|
||||
return val_obj.unordered.at(key);
|
||||
}
|
||||
virtual value & at(const std::string & key) {
|
||||
auto it = val_obj.unordered.find(key);
|
||||
if (it == val_obj.unordered.end()) {
|
||||
throw std::runtime_error("Key '" + key + "' not found in value of type " + type());
|
||||
}
|
||||
return val_obj.unordered.at(key);
|
||||
}
|
||||
virtual value & at(int64_t index, value & default_val) {
|
||||
if (index < 0) {
|
||||
index += val_arr.size();
|
||||
}
|
||||
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
|
||||
return default_val;
|
||||
}
|
||||
return val_arr[index];
|
||||
}
|
||||
virtual value & at(int64_t index) {
|
||||
if (index < 0) {
|
||||
index += val_arr.size();
|
||||
}
|
||||
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
|
||||
throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
|
||||
}
|
||||
return val_arr[index];
|
||||
}
|
||||
virtual bool has_key(const value &) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual void insert(const value & /* key */, const value & /* val */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(const value & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(const value & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(const std::string & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(const std::string & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(int64_t /* idx */, value & /* default_val */) { throw std::runtime_error(type() + " is not an array value"); }
|
||||
virtual value & at(int64_t /* idx */) { throw std::runtime_error(type() + " is not an array value"); }
|
||||
|
||||
virtual bool is_numeric() const { return false; }
|
||||
virtual bool is_hashable() const { return false; }
|
||||
virtual bool is_immutable() const { return true; }
|
||||
virtual hasher unique_hash() const noexcept = 0;
|
||||
// TODO: C++20 <=> operator
|
||||
// NOTE: We are treating == as equivalent (for normal comparisons) and != as strict nonequal (for strict (is) comparisons)
|
||||
virtual bool operator==(const value_t & other) const { return equivalent(other); }
|
||||
virtual bool operator!=(const value_t & other) const { return nonequal(other); }
|
||||
|
||||
// Note: only for debugging purposes
|
||||
virtual std::string as_repr() const { return as_string().str(); }
|
||||
|
||||
protected:
|
||||
virtual bool equivalent(const value_t &) const = 0;
|
||||
virtual bool nonequal(const value_t & other) const { return !equivalent(other); }
|
||||
};
|
||||
|
||||
//
|
||||
// utils
|
||||
//
|
||||
|
||||
const func_builtins & global_builtins();
|
||||
|
||||
std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": ");
|
||||
|
||||
// Note: only used for debugging purposes
|
||||
std::string value_to_string_repr(const value & val);
|
||||
|
||||
struct not_implemented_exception : public std::runtime_error {
|
||||
not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {}
|
||||
};
|
||||
|
||||
struct value_hasher {
|
||||
size_t operator()(const value & val) const noexcept {
|
||||
return val->unique_hash().digest();
|
||||
}
|
||||
};
|
||||
|
||||
struct value_equivalence {
|
||||
bool operator()(const value & lhs, const value & rhs) const {
|
||||
return *lhs == *rhs;
|
||||
}
|
||||
bool operator()(const std::pair<value, value> & lhs, const std::pair<value, value> & rhs) const {
|
||||
return *(lhs.first) == *(rhs.first) && *(lhs.second) == *(rhs.second);
|
||||
}
|
||||
};
|
||||
|
||||
struct value_equality {
|
||||
bool operator()(const value & lhs, const value & rhs) const {
|
||||
return !(*lhs != *rhs);
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
@@ -198,24 +207,49 @@ struct value_t {
|
||||
//
|
||||
|
||||
struct value_int_t : public value_t {
|
||||
value_int_t(int64_t v) { val_int = v; }
|
||||
value_int_t(int64_t v) {
|
||||
val_int = v;
|
||||
val_flt = static_cast<double>(v);
|
||||
if (static_cast<int64_t>(val_flt) != v) {
|
||||
val_flt = v < 0 ? -INFINITY : INFINITY;
|
||||
}
|
||||
}
|
||||
virtual std::string type() const override { return "Integer"; }
|
||||
virtual int64_t as_int() const override { return val_int; }
|
||||
virtual double as_float() const override { return static_cast<double>(val_int); }
|
||||
virtual double as_float() const override { return val_flt; }
|
||||
virtual string as_string() const override { return std::to_string(val_int); }
|
||||
virtual bool as_bool() const override {
|
||||
return val_int != 0;
|
||||
}
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual bool is_numeric() const override { return true; }
|
||||
virtual bool is_hashable() const override { return true; }
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
return hasher(typeid(*this))
|
||||
.update(&val_int, sizeof(val_int))
|
||||
.update(&val_flt, sizeof(val_flt));
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
|
||||
}
|
||||
virtual bool nonequal(const value_t & other) const override {
|
||||
return !(typeid(*this) == typeid(other) && val_int == other.val_int);
|
||||
}
|
||||
};
|
||||
using value_int = std::shared_ptr<value_int_t>;
|
||||
|
||||
|
||||
struct value_float_t : public value_t {
|
||||
value_float_t(double v) { val_flt = v; }
|
||||
value val;
|
||||
value_float_t(double v) {
|
||||
val_flt = v;
|
||||
val_int = std::isfinite(v) ? static_cast<int64_t>(v) : 0;
|
||||
val = mk_val<value_int>(val_int);
|
||||
}
|
||||
virtual std::string type() const override { return "Float"; }
|
||||
virtual double as_float() const override { return val_flt; }
|
||||
virtual int64_t as_int() const override { return static_cast<int64_t>(val_flt); }
|
||||
virtual int64_t as_int() const override { return val_int; }
|
||||
virtual string as_string() const override {
|
||||
std::string out = std::to_string(val_flt);
|
||||
out.erase(out.find_last_not_of('0') + 1, std::string::npos); // remove trailing zeros
|
||||
@@ -226,6 +260,24 @@ struct value_float_t : public value_t {
|
||||
return val_flt != 0.0;
|
||||
}
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual bool is_numeric() const override { return true; }
|
||||
virtual bool is_hashable() const override { return true; }
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
if (static_cast<double>(val_int) == val_flt) {
|
||||
return val->unique_hash();
|
||||
} else {
|
||||
return hasher(typeid(*this))
|
||||
.update(&val_int, sizeof(val_int))
|
||||
.update(&val_flt, sizeof(val_flt));
|
||||
}
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
|
||||
}
|
||||
virtual bool nonequal(const value_t & other) const override {
|
||||
return !(typeid(*this) == typeid(other) && val_flt == other.val_flt);
|
||||
}
|
||||
};
|
||||
using value_float = std::shared_ptr<value_float_t>;
|
||||
|
||||
@@ -247,19 +299,49 @@ struct value_string_t : public value_t {
|
||||
return val_str.length() > 0;
|
||||
}
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual bool is_hashable() const override { return true; }
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
const auto type_hash = typeid(*this).hash_code();
|
||||
auto hash = hasher();
|
||||
hash.update(&type_hash, sizeof(type_hash));
|
||||
val_str.hash_update(hash);
|
||||
return hash;
|
||||
}
|
||||
void mark_input() {
|
||||
val_str.mark_input();
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return typeid(*this) == typeid(other) && val_str.str() == other.val_str.str();
|
||||
}
|
||||
};
|
||||
using value_string = std::shared_ptr<value_string_t>;
|
||||
|
||||
|
||||
struct value_bool_t : public value_t {
|
||||
value_bool_t(bool v) { val_bool = v; }
|
||||
value val;
|
||||
value_bool_t(bool v) {
|
||||
val_int = static_cast<int64_t>(v);
|
||||
val_flt = static_cast<double>(v);
|
||||
val = mk_val<value_int>(val_int);
|
||||
}
|
||||
virtual std::string type() const override { return "Boolean"; }
|
||||
virtual bool as_bool() const override { return val_bool; }
|
||||
virtual string as_string() const override { return std::string(val_bool ? "True" : "False"); }
|
||||
virtual int64_t as_int() const override { return val_int; }
|
||||
virtual bool as_bool() const override { return val_int; }
|
||||
virtual string as_string() const override { return std::string(val_int ? "True" : "False"); }
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual bool is_numeric() const override { return true; }
|
||||
virtual bool is_hashable() const override { return true; }
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
return val->unique_hash();
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
|
||||
}
|
||||
virtual bool nonequal(const value_t & other) const override {
|
||||
return !(typeid(*this) == typeid(other) && val_int == other.val_int);
|
||||
}
|
||||
};
|
||||
using value_bool = std::shared_ptr<value_bool_t>;
|
||||
|
||||
@@ -269,13 +351,34 @@ struct value_array_t : public value_t {
|
||||
value_array_t(value & v) {
|
||||
val_arr = v->val_arr;
|
||||
}
|
||||
value_array_t(std::vector<value> && arr) {
|
||||
val_arr = arr;
|
||||
}
|
||||
value_array_t(const std::vector<value> & arr) {
|
||||
val_arr = arr;
|
||||
}
|
||||
void reverse() { std::reverse(val_arr.begin(), val_arr.end()); }
|
||||
void push_back(const value & val) { val_arr.push_back(val); }
|
||||
void push_back(value && val) { val_arr.push_back(std::move(val)); }
|
||||
void reverse() {
|
||||
if (is_immutable()) {
|
||||
throw std::runtime_error("Attempting to modify immutable type");
|
||||
}
|
||||
std::reverse(val_arr.begin(), val_arr.end());
|
||||
}
|
||||
void push_back(const value & val) {
|
||||
if (is_immutable()) {
|
||||
throw std::runtime_error("Attempting to modify immutable type");
|
||||
}
|
||||
val_arr.push_back(val);
|
||||
}
|
||||
void push_back(value && val) {
|
||||
if (is_immutable()) {
|
||||
throw std::runtime_error("Attempting to modify immutable type");
|
||||
}
|
||||
val_arr.push_back(std::move(val));
|
||||
}
|
||||
value pop_at(int64_t index) {
|
||||
if (is_immutable()) {
|
||||
throw std::runtime_error("Attempting to modify immutable type");
|
||||
}
|
||||
if (index < 0) {
|
||||
index = static_cast<int64_t>(val_arr.size()) + index;
|
||||
}
|
||||
@@ -287,64 +390,225 @@ struct value_array_t : public value_t {
|
||||
return val;
|
||||
}
|
||||
virtual std::string type() const override { return "Array"; }
|
||||
virtual bool is_immutable() const override { return false; }
|
||||
virtual const std::vector<value> & as_array() const override { return val_arr; }
|
||||
virtual string as_string() const override {
|
||||
const bool immutable = is_immutable();
|
||||
std::ostringstream ss;
|
||||
ss << "[";
|
||||
ss << (immutable ? "(" : "[");
|
||||
for (size_t i = 0; i < val_arr.size(); i++) {
|
||||
if (i > 0) ss << ", ";
|
||||
ss << val_arr.at(i)->as_repr();
|
||||
value val = val_arr.at(i);
|
||||
ss << value_to_string_repr(val);
|
||||
}
|
||||
ss << "]";
|
||||
if (immutable && val_arr.size() == 1) {
|
||||
ss << ",";
|
||||
}
|
||||
ss << (immutable ? ")" : "]");
|
||||
return ss.str();
|
||||
}
|
||||
virtual bool as_bool() const override {
|
||||
return !val_arr.empty();
|
||||
}
|
||||
virtual value & at(int64_t index, value & default_val) override {
|
||||
if (index < 0) {
|
||||
index += val_arr.size();
|
||||
}
|
||||
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
|
||||
return default_val;
|
||||
}
|
||||
return val_arr[index];
|
||||
}
|
||||
virtual value & at(int64_t index) override {
|
||||
if (index < 0) {
|
||||
index += val_arr.size();
|
||||
}
|
||||
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
|
||||
throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
|
||||
}
|
||||
return val_arr[index];
|
||||
}
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual bool is_hashable() const override {
|
||||
if (std::all_of(val_arr.begin(), val_arr.end(), [&](auto & val) -> bool {
|
||||
return val->is_immutable() && val->is_hashable();
|
||||
})) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
auto hash = hasher(typeid(*this));
|
||||
for (const auto & val : val_arr) {
|
||||
// must use digest to prevent problems from "concatenation" property of hasher
|
||||
// for ex. hash of [ "ab", "c" ] should be different from [ "a", "bc" ]
|
||||
const size_t val_hash = val->unique_hash().digest();
|
||||
hash.update(&val_hash, sizeof(size_t));
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence());
|
||||
}
|
||||
};
|
||||
using value_array = std::shared_ptr<value_array_t>;
|
||||
|
||||
|
||||
struct value_tuple_t : public value_array_t {
|
||||
value_tuple_t(value & v) {
|
||||
val_arr = v->val_arr;
|
||||
}
|
||||
value_tuple_t(std::vector<value> && arr) {
|
||||
val_arr = arr;
|
||||
}
|
||||
value_tuple_t(const std::vector<value> & arr) {
|
||||
val_arr = arr;
|
||||
}
|
||||
value_tuple_t(const std::pair<value, value> & pair) {
|
||||
val_arr.push_back(pair.first);
|
||||
val_arr.push_back(pair.second);
|
||||
}
|
||||
virtual std::string type() const override { return "Tuple"; }
|
||||
virtual bool is_immutable() const override { return true; }
|
||||
};
|
||||
using value_tuple = std::shared_ptr<value_tuple_t>;
|
||||
|
||||
|
||||
struct value_object_t : public value_t {
|
||||
std::unordered_map<value, value, value_hasher, value_equivalence> unordered;
|
||||
bool has_builtins = true; // context and loop objects do not have builtins
|
||||
value_object_t() = default;
|
||||
value_object_t(value & v) {
|
||||
val_obj = v->val_obj;
|
||||
}
|
||||
value_object_t(const std::map<std::string, value> & obj) {
|
||||
for (const auto & pair : obj) {
|
||||
val_obj.insert(pair.first, pair.second);
|
||||
for (const auto & pair : val_obj) {
|
||||
unordered[pair.first] = pair.second;
|
||||
}
|
||||
}
|
||||
value_object_t(const std::vector<std::pair<std::string, value>> & obj) {
|
||||
value_object_t(const std::map<value, value> & obj) {
|
||||
for (const auto & pair : obj) {
|
||||
val_obj.insert(pair.first, pair.second);
|
||||
insert(pair.first, pair.second);
|
||||
}
|
||||
}
|
||||
value_object_t(const std::vector<std::pair<value, value>> & obj) {
|
||||
for (const auto & pair : obj) {
|
||||
insert(pair.first, pair.second);
|
||||
}
|
||||
}
|
||||
void insert(const std::string & key, const value & val) {
|
||||
val_obj.insert(key, val);
|
||||
insert(mk_val<value_string>(key), val);
|
||||
}
|
||||
virtual std::string type() const override { return "Object"; }
|
||||
virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const override { return val_obj.ordered; }
|
||||
virtual bool is_immutable() const override { return false; }
|
||||
virtual const std::vector<std::pair<value, value>> & as_ordered_object() const override { return val_obj; }
|
||||
virtual string as_string() const override {
|
||||
std::ostringstream ss;
|
||||
ss << "{";
|
||||
for (size_t i = 0; i < val_obj.size(); i++) {
|
||||
if (i > 0) ss << ", ";
|
||||
auto & [key, val] = val_obj.at(i);
|
||||
ss << value_to_string_repr(key) << ": " << value_to_string_repr(val);
|
||||
}
|
||||
ss << "}";
|
||||
return ss.str();
|
||||
}
|
||||
virtual bool as_bool() const override {
|
||||
return !val_obj.unordered.empty();
|
||||
return !unordered.empty();
|
||||
}
|
||||
virtual bool has_key(const value & key) override {
|
||||
if (!key->is_immutable() || !key->is_hashable()) {
|
||||
throw std::runtime_error("Object key of unhashable type: " + key->type());
|
||||
}
|
||||
return unordered.find(key) != unordered.end();
|
||||
}
|
||||
virtual void insert(const value & key, const value & val) override {
|
||||
bool replaced = false;
|
||||
if (is_immutable()) {
|
||||
throw std::runtime_error("Attempting to modify immutable type");
|
||||
}
|
||||
if (has_key(key)) {
|
||||
// if key exists, replace value in ordered list instead of appending
|
||||
for (auto & pair : val_obj) {
|
||||
if (*(pair.first) == *key) {
|
||||
pair.second = val;
|
||||
replaced = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
unordered[key] = val;
|
||||
if (!replaced) {
|
||||
val_obj.push_back({key, val});
|
||||
}
|
||||
}
|
||||
virtual value & at(const value & key, value & default_val) override {
|
||||
if (!has_key(key)) {
|
||||
return default_val;
|
||||
}
|
||||
return unordered.at(key);
|
||||
}
|
||||
virtual value & at(const value & key) override {
|
||||
if (!has_key(key)) {
|
||||
throw std::runtime_error("Key '" + key->as_string().str() + "' not found in value of type " + type());
|
||||
}
|
||||
return unordered.at(key);
|
||||
}
|
||||
virtual value & at(const std::string & key, value & default_val) override {
|
||||
value key_val = mk_val<value_string>(key);
|
||||
return at(key_val, default_val);
|
||||
}
|
||||
virtual value & at(const std::string & key) override {
|
||||
value key_val = mk_val<value_string>(key);
|
||||
return at(key_val);
|
||||
}
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual bool is_hashable() const override {
|
||||
if (std::all_of(val_obj.begin(), val_obj.end(), [&](auto & pair) -> bool {
|
||||
const auto & val = pair.second;
|
||||
return val->is_immutable() && val->is_hashable();
|
||||
})) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
auto hash = hasher(typeid(*this));
|
||||
for (const auto & [key, val] : val_obj) {
|
||||
// must use digest to prevent problems from "concatenation" property of hasher
|
||||
// for ex. hash of key="ab", value="c" should be different from key="a", value="bc"
|
||||
const size_t key_hash = key->unique_hash().digest();
|
||||
const size_t val_hash = val->unique_hash().digest();
|
||||
hash.update(&key_hash, sizeof(key_hash));
|
||||
hash.update(&val_hash, sizeof(val_hash));
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence());
|
||||
}
|
||||
};
|
||||
using value_object = std::shared_ptr<value_object_t>;
|
||||
|
||||
//
|
||||
// null and undefined types
|
||||
// none and undefined types
|
||||
//
|
||||
|
||||
struct value_none_t : public value_t {
|
||||
virtual std::string type() const override { return "None"; }
|
||||
virtual bool is_none() const override { return true; }
|
||||
virtual bool as_bool() const override { return false; }
|
||||
virtual string as_string() const override { return string("None"); }
|
||||
virtual string as_string() const override { return string(type()); }
|
||||
virtual std::string as_repr() const override { return type(); }
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual bool is_hashable() const override { return true; }
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
return hasher(typeid(*this));
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return typeid(*this) == typeid(other);
|
||||
}
|
||||
};
|
||||
using value_none = std::shared_ptr<value_none_t>;
|
||||
|
||||
@@ -356,6 +620,13 @@ struct value_undefined_t : public value_t {
|
||||
virtual bool as_bool() const override { return false; }
|
||||
virtual std::string as_repr() const override { return type(); }
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
return hasher(typeid(*this));
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return is_undefined() == other.is_undefined();
|
||||
}
|
||||
};
|
||||
using value_undefined = std::shared_ptr<value_undefined_t>;
|
||||
|
||||
@@ -436,7 +707,23 @@ struct value_func_t : public value_t {
|
||||
return val_func(new_args);
|
||||
}
|
||||
virtual std::string type() const override { return "Function"; }
|
||||
virtual std::string as_repr() const override { return type(); }
|
||||
virtual std::string as_repr() const override { return type() + "<" + name + ">(" + (arg0 ? arg0->as_repr() : "") + ")"; }
|
||||
virtual bool is_hashable() const override { return false; }
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
// Note: this is unused for now, we don't support function as object keys
|
||||
// use function pointer as unique identifier
|
||||
const auto target = val_func.target<func_hptr>();
|
||||
return hasher(typeid(*this)).update(&target, sizeof(target));
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
// Note: this is unused for now, we don't support function as object keys
|
||||
// compare function pointers
|
||||
// (val_func == other.val_func does not work as std::function::operator== is only used for nullptr check)
|
||||
const auto target_this = this->val_func.target<func_hptr>();
|
||||
const auto target_other = other.val_func.target<func_hptr>();
|
||||
return typeid(*this) == typeid(other) && target_this == target_other;
|
||||
}
|
||||
};
|
||||
using value_func = std::shared_ptr<value_func_t>;
|
||||
|
||||
@@ -447,18 +734,21 @@ struct value_kwarg_t : public value_t {
|
||||
value_kwarg_t(const std::string & k, const value & v) : key(k), val(v) {}
|
||||
virtual std::string type() const override { return "KwArg"; }
|
||||
virtual std::string as_repr() const override { return type(); }
|
||||
virtual bool is_hashable() const override { return true; }
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
const auto type_hash = typeid(*this).hash_code();
|
||||
auto hash = val->unique_hash();
|
||||
hash.update(&type_hash, sizeof(type_hash))
|
||||
.update(key.data(), key.size());
|
||||
return hash;
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
const value_kwarg_t & other_val = static_cast<const value_kwarg_t &>(other);
|
||||
return typeid(*this) == typeid(other) && key == other_val.key && val == other_val.val;
|
||||
}
|
||||
};
|
||||
using value_kwarg = std::shared_ptr<value_kwarg_t>;
|
||||
|
||||
|
||||
// utils
|
||||
|
||||
const func_builtins & global_builtins();
|
||||
std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": ");
|
||||
|
||||
struct not_implemented_exception : public std::runtime_error {
|
||||
not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {}
|
||||
};
|
||||
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -192,12 +192,12 @@ void common_ngram_cache_draft(
|
||||
break;
|
||||
}
|
||||
|
||||
LOG(" - draft candidate: token=%d\n", drafted_token);
|
||||
LOG_DBG(" - draft candidate: token=%d\n", drafted_token);
|
||||
draft.push_back(drafted_token);
|
||||
}
|
||||
}
|
||||
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename) {
|
||||
std::ofstream file_out(filename, std::ios::binary);
|
||||
for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
|
||||
const common_ngram ngram = item.first;
|
||||
@@ -217,10 +217,9 @@ void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & fil
|
||||
file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
common_ngram_cache common_ngram_cache_load(std::string & filename) {
|
||||
common_ngram_cache common_ngram_cache_load(const std::string & filename) {
|
||||
std::ifstream hashmap_file(filename, std::ios::binary);
|
||||
if (!hashmap_file) {
|
||||
throw std::ifstream::failure("Unable to open file " + filename);
|
||||
|
||||
@@ -88,12 +88,12 @@ void common_ngram_cache_draft(
|
||||
// Save an ngram cache to a file.
|
||||
// ngram_cache: the ngram cache to save.
|
||||
// filename: the path under which to save the ngram cache.
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename);
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename);
|
||||
|
||||
// Load an ngram cache saved with common_ngram_cache_save.
|
||||
// filename: the path from which to load the ngram cache.
|
||||
// returns: an ngram cache containing the information saved to filename.
|
||||
common_ngram_cache common_ngram_cache_load(std::string & filename);
|
||||
common_ngram_cache common_ngram_cache_load(const std::string & filename);
|
||||
|
||||
// Merge two ngram caches.
|
||||
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
|
||||
|
||||
362
common/ngram-map.cpp
Normal file
362
common/ngram-map.cpp
Normal file
@@ -0,0 +1,362 @@
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "ngram-map.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <sstream>
|
||||
|
||||
// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
|
||||
static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
|
||||
std::ostringstream oss;
|
||||
oss << '[';
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (i > 0) {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << inp[start + i];
|
||||
}
|
||||
oss << ']';
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
||||
// n-gram simple
|
||||
//
|
||||
|
||||
/**
|
||||
* Perform speculative generation using the model's own token history.
|
||||
* Searches for a matching pattern in the token history and returns draft tokens.
|
||||
*
|
||||
* @param state Current state of this implementation
|
||||
* @param tokens Token history to search in
|
||||
* @param sampled Last sampled token
|
||||
* @return Vector of draft tokens, empty if no matching pattern is found
|
||||
*/
|
||||
llama_tokens common_ngram_simple_draft(
|
||||
common_ngram_simple_state & state,
|
||||
const llama_tokens & tokens, llama_token sampled) {
|
||||
|
||||
// Simple implementation of self-speculative decoding without a draft model.
|
||||
//
|
||||
const size_t cur_len = tokens.size();
|
||||
// Only check every check_rate tokens to save compute
|
||||
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
|
||||
if (state.idx_last_check + state.config.check_rate > cur_len) {
|
||||
llama_tokens draft_tokens;
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
size_t n_draft_min = state.config.size_ngram; // size of n-gram to lookup in token history
|
||||
size_t n_draft_max = state.config.size_mgram; // the m-gram following the found n-gram is used for draft
|
||||
|
||||
// vector for tokens we want to verify.
|
||||
// return empty vector if there is no match.
|
||||
llama_tokens draft_tokens;
|
||||
|
||||
// We need at least n_draft_min + n_draft_max + 1 tokens.
|
||||
if (cur_len <= static_cast<size_t>(n_draft_min + n_draft_max + 1)) {
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
// pattern search
|
||||
llama_tokens pattern;
|
||||
pattern.reserve(n_draft_min);
|
||||
for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) {
|
||||
pattern.push_back(tokens[j]);
|
||||
}
|
||||
pattern.push_back(sampled); // add the last token to the pattern
|
||||
|
||||
// We do a search in the token history.
|
||||
state.idx_last_check = cur_len;
|
||||
|
||||
size_t match_pos = 0; // we ignore position 0, position 0 == no match
|
||||
// search backwards, but skip the current match (we are currently there)
|
||||
for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
|
||||
bool match = true;
|
||||
for (size_t k = 0; k < pattern.size(); ++k) {
|
||||
if (tokens[j + k] != pattern[k]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
match_pos = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match_pos == 0) {
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
const size_t copy_max = std::min(
|
||||
n_draft_max,
|
||||
cur_len - (match_pos + n_draft_min)
|
||||
);
|
||||
if (copy_max < n_draft_min) {
|
||||
return draft_tokens;
|
||||
}
|
||||
LOG_DBG("%s: #tokens = %zu: found matching pattern at pos %zu, length %zu, draft length %zu\n",
|
||||
__func__, cur_len,
|
||||
match_pos, pattern.size(), copy_max);
|
||||
|
||||
draft_tokens.reserve(copy_max);
|
||||
for (size_t j = 0; j < copy_max; ++j) {
|
||||
draft_tokens.push_back(tokens[match_pos + n_draft_min + j]);
|
||||
}
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
|
||||
// n-gram map
|
||||
//
|
||||
|
||||
// maximum number of counted values of a ngram map value.
|
||||
#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
|
||||
|
||||
void common_ngram_map_draft(common_ngram_map & map,
|
||||
const llama_tokens & inp, llama_token sampled,
|
||||
llama_tokens & draft) {
|
||||
// reset last key and value.
|
||||
map.last_draft_created = false;
|
||||
map.last_draft_key_idx = 0;
|
||||
map.last_draft_value_idx = 0;
|
||||
|
||||
const size_t cur_len = inp.size();
|
||||
const uint16_t n = map.size_key;
|
||||
const uint16_t m = map.size_value;
|
||||
if (cur_len < static_cast<size_t>(2 * n + m)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Only check every check_rate tokens to save compute
|
||||
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
|
||||
if (map.idx_last_check + map.check_rate > cur_len) {
|
||||
return;
|
||||
}
|
||||
map.idx_last_check = cur_len;
|
||||
|
||||
// search pattern, the key n-gram
|
||||
std::vector<llama_token> key_tokens;
|
||||
key_tokens.reserve(n);
|
||||
for (size_t j = cur_len - n + 1; j < cur_len; ++j) {
|
||||
key_tokens.push_back(inp[j]);
|
||||
}
|
||||
key_tokens.push_back(sampled);
|
||||
|
||||
// search for the key in the map
|
||||
size_t match_pos = 0;
|
||||
for (size_t j = cur_len - n - m - 1; j > 0; --j) {
|
||||
bool match = true;
|
||||
for (size_t k = 0; k < n; ++k) {
|
||||
if (inp[j + k] != key_tokens[k]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
match_pos = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match_pos > 0) {
|
||||
LOG_INF("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
|
||||
cur_len, n, m, key_tokens.size(), sampled, match_pos);
|
||||
}
|
||||
|
||||
if (match_pos == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// We have a match, now we look for the statistics of the key.
|
||||
size_t key_offset = map.keys.size(); // offset in the map
|
||||
// We iterate through the std::vector<common_ngram_map_key> map->keys.
|
||||
for (size_t i = 0; i < map.keys.size(); ++i) {
|
||||
bool match = true;
|
||||
for (size_t j = 0; j < n; ++j) {
|
||||
if (inp[map.keys[i].key_idx + j] != key_tokens[j]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
key_offset = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (key_offset == map.keys.size()) {
|
||||
// We create a new key-entry, it will get offset key_offset.
|
||||
common_ngram_map_key new_key;
|
||||
new_key.key_idx = match_pos;
|
||||
new_key.stat_idx = 0;
|
||||
new_key.key_num = 0;
|
||||
for (int i = 0; i < COMMON_NGRAM_MAX_VALUES; ++i) {
|
||||
new_key.values[i].value_num = 0;
|
||||
new_key.values[i].n_accepted = m;
|
||||
}
|
||||
map.keys.push_back(new_key);
|
||||
}
|
||||
|
||||
// our key n-gram:
|
||||
common_ngram_map_key & curr_key = map.keys[key_offset];
|
||||
|
||||
// update number of key hits
|
||||
curr_key.key_num = (uint16_t) std::min((int) map.keys[key_offset].key_num + 1,
|
||||
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
|
||||
|
||||
if (map.key_only) {
|
||||
// simple mode:
|
||||
// Fill in the draft with the m tokens following the key.
|
||||
// We work with value values[0] only.
|
||||
int n_draft_tokens = std::min((int) m, (int) curr_key.values[0].n_accepted);
|
||||
|
||||
for (int i = 0; i < n_draft_tokens; ++i) {
|
||||
draft.push_back(inp[match_pos + n + i]);
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
|
||||
key_offset, curr_key.key_num, draft.size());
|
||||
|
||||
map.last_draft_created = false;
|
||||
map.last_draft_key_idx = key_offset;
|
||||
map.last_draft_value_idx = 0; // value 0 is used for simple mode
|
||||
return;
|
||||
}
|
||||
|
||||
if (curr_key.key_num < map.min_hits) {
|
||||
// not enough hits to consider this a good draft
|
||||
LOG_DBG("%s: key_offset = %zu, key_num = %d, min_hits = %d, no draft\n", __func__,
|
||||
key_offset, curr_key.key_num, map.min_hits);
|
||||
return;
|
||||
}
|
||||
|
||||
// complex mode: examine the different m-grams after this key n-gram.
|
||||
//
|
||||
|
||||
// determine all (max COMMON_NGRAM_MAX_VALUES) m-grams after the key n-gram.
|
||||
for (size_t i = curr_key.stat_idx; i <= match_pos; ++i) {
|
||||
// begins the key n-gram at index i?
|
||||
bool match_key = true;
|
||||
for (size_t k = 0; k < n; ++k) {
|
||||
if (inp[i + k] != key_tokens[k]) {
|
||||
match_key = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match_key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Do we haven a existing value m-gram or a new one after the key at index i?
|
||||
size_t idx_begin_value_key = i + n;
|
||||
int idx_value = -1;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
size_t idx_begin_value_v = curr_key.values[v].value_idx;
|
||||
if (idx_begin_value_v == 0) {
|
||||
// We found an empty value slot => we found a new value m-gram after the key n-gram.
|
||||
curr_key.values[v].value_idx = idx_begin_value_key;
|
||||
curr_key.values[v].value_num = 0;
|
||||
curr_key.values[v].n_accepted = m;
|
||||
idx_value = v;
|
||||
break;
|
||||
}
|
||||
bool match = true;
|
||||
for (size_t j = 0; j < m; ++j) {
|
||||
if (inp[idx_begin_value_key + j] != inp[idx_begin_value_v + j]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
// We found an existing value m-gram after the key n-gram.
|
||||
idx_value = v;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (idx_value >= 0) {
|
||||
// We found a value m-gram of the key n-gram.
|
||||
curr_key.values[idx_value].value_num = (uint16_t) std::min((int) curr_key.values[idx_value].value_num + 1,
|
||||
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
|
||||
}
|
||||
}
|
||||
// the statistics are updated up to match_pos.
|
||||
curr_key.stat_idx = match_pos;
|
||||
|
||||
// Do we have a value we could use for the draft?
|
||||
uint16_t max_occur = 0;
|
||||
int slot_max = 0;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
uint16_t curr_occur = curr_key.values[v].value_num;
|
||||
if (curr_occur > max_occur) {
|
||||
max_occur = curr_occur;
|
||||
slot_max = v;
|
||||
}
|
||||
}
|
||||
// What is sum of the other occurences?
|
||||
uint32_t sum_occur = 0;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
if (v == slot_max) {
|
||||
continue;
|
||||
}
|
||||
uint16_t curr_occur = curr_key.values[v].value_num;
|
||||
sum_occur += curr_occur;
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__,
|
||||
key_offset,
|
||||
max_occur, sum_occur, slot_max,
|
||||
curr_key.values[0].value_idx, curr_key.values[0].value_num,
|
||||
curr_key.values[1].value_idx, curr_key.values[1].value_num,
|
||||
curr_key.values[2].value_idx, curr_key.values[2].value_num,
|
||||
curr_key.values[3].value_idx, curr_key.values[3].value_num
|
||||
);
|
||||
// Print the tokens of the four values (if idx != 0), use LOG_INF
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
if (curr_key.values[v].value_idx != 0) {
|
||||
LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (sum_occur > 0 && max_occur < 3 * sum_occur) {
|
||||
// The most frequent value is not much more frequent than the other values.
|
||||
// We do not use the draft.
|
||||
return;
|
||||
}
|
||||
|
||||
// We use the most frequent value values[slot_max] for the draft.
|
||||
// Fill in the draft with the m tokens following the key.
|
||||
int n_draft_tokens = std::min((int) m, (int) curr_key.values[slot_max].n_accepted);
|
||||
|
||||
for (int i = 0; i < n_draft_tokens; ++i) {
|
||||
draft.push_back(inp[match_pos + n + i]);
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
|
||||
key_offset, slot_max,
|
||||
curr_key.key_num, draft.size());
|
||||
|
||||
map.last_draft_created = true;
|
||||
map.last_draft_key_idx = key_offset;
|
||||
map.last_draft_value_idx = slot_max; // value used for draft generation.
|
||||
}
|
||||
|
||||
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
|
||||
if (!map.last_draft_created) {
|
||||
return;
|
||||
}
|
||||
|
||||
// find the key and its chosen value.
|
||||
const size_t key_idx = map.last_draft_key_idx;
|
||||
const size_t val_idx = map.last_draft_value_idx;
|
||||
|
||||
// find key corresponding to key_idx.
|
||||
common_ngram_map_key & curr_key = map.keys[key_idx];
|
||||
// find value corresponding to val_idx.
|
||||
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
|
||||
|
||||
// update the value statistics
|
||||
LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
|
||||
n_accepted, curr_value.n_accepted);
|
||||
curr_value.n_accepted = n_accepted;
|
||||
}
|
||||
106
common/ngram-map.h
Normal file
106
common/ngram-map.h
Normal file
@@ -0,0 +1,106 @@
|
||||
#pragma once
|
||||
//
|
||||
// common/ngram-map.h: structures used to manage a map from n-grams to a list of m-grams
|
||||
//
|
||||
// These structures are used to do a lookup of n-grams followed by m-grams in token history.
|
||||
//
|
||||
// There are two algorithms implemented:
|
||||
// 1. ngram_simple: lookup of n-grams followed by m-grams in token history.
|
||||
// 2. ngram_map: lookup of n-grams followed by m-grams in token history using a map.
|
||||
// The map is a vector of key n-grams, and for each key n-gram there is a list of value m-grams.
|
||||
//
|
||||
|
||||
#include "llama.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
// n-gram simple
|
||||
//
|
||||
|
||||
// config of n-gram simple.
|
||||
struct common_ngram_simple_config {
|
||||
uint16_t size_ngram; // size of n-grams to lookup in self-mode
|
||||
uint16_t size_mgram; // size of m-grams to draft in self-mode
|
||||
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
|
||||
};
|
||||
|
||||
// current state (and config) of n-gram simple.
|
||||
struct common_ngram_simple_state {
|
||||
common_ngram_simple_config config;
|
||||
|
||||
size_t idx_last_check = 0; // index of last check in context history (mutable)
|
||||
|
||||
common_ngram_simple_state(const common_ngram_simple_config & config)
|
||||
: config(config) {}
|
||||
};
|
||||
|
||||
// Searches for a n-gram in the history and checks whether a draft sequence should be generated.
|
||||
// state: the ngram simple state to search in.
|
||||
// inp: the tokens generated so far.
|
||||
// sampled: the token that was just sampled.
|
||||
// draft: vector to store the draft tokens, initially empty.
|
||||
llama_tokens common_ngram_simple_draft(
|
||||
common_ngram_simple_state & state,
|
||||
const llama_tokens & tokens, llama_token sampled);
|
||||
|
||||
|
||||
// n-gram map
|
||||
//
|
||||
|
||||
// maximum number of m-gram values stored for each key n-gram.
|
||||
#define COMMON_NGRAM_MAX_VALUES 4
|
||||
|
||||
// statistics of a m-gram after a known n-gram
|
||||
struct common_ngram_map_value {
|
||||
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
|
||||
uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot)
|
||||
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
|
||||
};
|
||||
|
||||
// statistics of a n-gram
|
||||
struct common_ngram_map_key {
|
||||
size_t key_idx; // index of key n-gram in token-history
|
||||
size_t stat_idx; // index of last token of stastistics computation (key_num, values)
|
||||
|
||||
uint16_t key_num; // number of occurences of this key n-gram in token-history
|
||||
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
|
||||
};
|
||||
|
||||
// map from n-grams to following m-grams in token-history
|
||||
struct common_ngram_map {
|
||||
uint16_t size_key; // size of key n-grams
|
||||
uint16_t size_value; // size of value m-grams
|
||||
|
||||
bool key_only; // true if only key n-grams are used, no values.
|
||||
|
||||
// first draft: vector only, no map.
|
||||
std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history
|
||||
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
|
||||
uint16_t min_hits; // minimum number of key hits to consider a draft
|
||||
|
||||
common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
|
||||
uint16_t check_rate, uint16_t min_hits)
|
||||
: size_key(sz_key), size_value(sz_value), key_only(only_keys),
|
||||
check_rate(check_rate), min_hits(min_hits) {}
|
||||
|
||||
bool last_draft_created = false; // true if a draft was created at last call.
|
||||
size_t last_draft_key_idx = 0; // index of last key used for draft generation.
|
||||
uint16_t last_draft_value_idx = 0; // index of last value used for draft generation.
|
||||
|
||||
size_t idx_last_check = 0; // index of last check in context history
|
||||
};
|
||||
|
||||
|
||||
// Searches for the n-gram in the history and checks whether a draft sequence should be generated.
|
||||
// map: the ngram map to search in.
|
||||
// inp: the tokens generated so far.
|
||||
// sampled: the token that was just sampled.
|
||||
// draft: vector to store the draft tokens, initially empty.
|
||||
void common_ngram_map_draft(
|
||||
common_ngram_map & map,
|
||||
const llama_tokens & inp, llama_token sampled,
|
||||
llama_tokens & draft);
|
||||
|
||||
// Update the statistics of a value after a draft was processed.
|
||||
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted);
|
||||
60
common/ngram-mod.cpp
Normal file
60
common/ngram-mod.cpp
Normal file
@@ -0,0 +1,60 @@
|
||||
#include "ngram-mod.h"
|
||||
|
||||
//
|
||||
// common_ngram_mod
|
||||
//
|
||||
|
||||
common_ngram_mod::common_ngram_mod(uint16_t n, size_t size) : n(n), used(0) {
|
||||
entries.resize(size);
|
||||
|
||||
reset();
|
||||
}
|
||||
|
||||
size_t common_ngram_mod::idx(const entry_t * tokens) const {
|
||||
size_t res = 0;
|
||||
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
res = res*6364136223846793005ULL + tokens[i];
|
||||
}
|
||||
|
||||
res = res % entries.size();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void common_ngram_mod::add(const entry_t * tokens) {
|
||||
const size_t i = idx(tokens);
|
||||
|
||||
if (entries[i] == EMPTY) {
|
||||
used++;
|
||||
}
|
||||
|
||||
entries[i] = tokens[n];
|
||||
}
|
||||
|
||||
common_ngram_mod::entry_t common_ngram_mod::get(const entry_t * tokens) const {
|
||||
const size_t i = idx(tokens);
|
||||
|
||||
return entries[i];
|
||||
}
|
||||
|
||||
void common_ngram_mod::reset() {
|
||||
std::fill(entries.begin(), entries.end(), EMPTY);
|
||||
used = 0;
|
||||
}
|
||||
|
||||
size_t common_ngram_mod::get_n() const {
|
||||
return n;
|
||||
}
|
||||
|
||||
size_t common_ngram_mod::get_used() const {
|
||||
return used;
|
||||
}
|
||||
|
||||
size_t common_ngram_mod::size() const {
|
||||
return entries.size();
|
||||
}
|
||||
|
||||
size_t common_ngram_mod::size_bytes() const {
|
||||
return entries.size() * sizeof(entries[0]);
|
||||
}
|
||||
38
common/ngram-mod.h
Normal file
38
common/ngram-mod.h
Normal file
@@ -0,0 +1,38 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <cstddef>
|
||||
|
||||
//
|
||||
// common_ngram_mod
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/19164
|
||||
//
|
||||
|
||||
// basic n-gram hasher
|
||||
struct common_ngram_mod {
|
||||
using entry_t = int32_t;
|
||||
|
||||
static constexpr entry_t EMPTY = -1;
|
||||
|
||||
common_ngram_mod(uint16_t n, size_t size);
|
||||
|
||||
size_t idx(const entry_t * tokens) const;
|
||||
void add(const entry_t * tokens);
|
||||
entry_t get(const entry_t * tokens) const; // return -1 if not found
|
||||
|
||||
void reset();
|
||||
|
||||
size_t get_n() const;
|
||||
size_t get_used() const;
|
||||
|
||||
size_t size() const;
|
||||
size_t size_bytes() const;
|
||||
|
||||
private:
|
||||
size_t n; // ngram size to hash
|
||||
|
||||
size_t used;
|
||||
|
||||
std::vector<entry_t> entries;
|
||||
};
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,31 +5,33 @@
|
||||
|
||||
struct common_speculative;
|
||||
|
||||
struct common_speculative_params {
|
||||
int n_draft = 16; // max drafted tokens
|
||||
int n_reuse = 256;
|
||||
// comma separated list of all types
|
||||
std::string common_speculative_type_name_str();
|
||||
|
||||
float p_min = 0.75f; // min probability required to accept a token in the draft
|
||||
};
|
||||
// convert string to type
|
||||
enum common_speculative_type common_speculative_type_from_name(const std::string & name);
|
||||
|
||||
struct common_speculative * common_speculative_init(
|
||||
struct llama_context * ctx_tgt,
|
||||
struct llama_context * ctx_dft
|
||||
);
|
||||
// convert type to string
|
||||
std::string common_speculative_type_to_str(enum common_speculative_type type);
|
||||
|
||||
void common_speculative_free(struct common_speculative * spec);
|
||||
common_speculative * common_speculative_init(
|
||||
common_params_speculative & params,
|
||||
llama_context * ctx_tgt);
|
||||
|
||||
bool common_speculative_are_compatible(
|
||||
const struct llama_context * ctx_tgt,
|
||||
const struct llama_context * ctx_dft);
|
||||
void common_speculative_free(common_speculative * spec);
|
||||
|
||||
void common_speculative_add_replacement_tgt_dft(
|
||||
struct common_speculative * spec,
|
||||
const char *source, const char *dest);
|
||||
// optionally call once at the beginning of a new generation
|
||||
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
|
||||
|
||||
// sample up to n_draft tokens and add them to the batch using the draft model
|
||||
llama_tokens common_speculative_gen_draft(
|
||||
struct common_speculative * spec,
|
||||
struct common_speculative_params params,
|
||||
const llama_tokens & prompt,
|
||||
llama_token id_last);
|
||||
llama_tokens common_speculative_draft(
|
||||
common_speculative * spec,
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt,
|
||||
llama_token id_last);
|
||||
|
||||
// informs the speculative decoder that n_accepted tokens were accepted by the target model
|
||||
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
||||
|
||||
// print statistics about the speculative decoding
|
||||
void common_speculative_print_stats(const common_speculative * spec);
|
||||
|
||||
@@ -8806,6 +8806,7 @@ class GraniteMoeModel(GraniteModel):
|
||||
gate, up = data_torch.split(ffn_dim, dim=-2)
|
||||
yield from ModelBase.modify_tensors(self, gate, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), bid)
|
||||
yield from ModelBase.modify_tensors(self, up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), bid)
|
||||
return
|
||||
|
||||
has_experts = bool(self.hparams.get('num_local_experts'))
|
||||
|
||||
@@ -8912,13 +8913,16 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
|
||||
name.endswith("block_sparse_moe.input_linear.weight")
|
||||
or "shared_mlp" in name
|
||||
):
|
||||
return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
|
||||
yield from GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
|
||||
return
|
||||
|
||||
# Determine whether this is a mamba layer or an attention layer
|
||||
if bid in self._ssm_layers:
|
||||
return Mamba2Model.modify_tensors(self, data_torch, name, bid)
|
||||
yield from Mamba2Model.modify_tensors(self, data_torch, name, bid)
|
||||
return
|
||||
elif bid in self._attn_layers:
|
||||
return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
|
||||
yield from GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
|
||||
return
|
||||
yield from ModelBase.modify_tensors(self, data_torch, name, bid)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
|
||||
@@ -35,9 +35,9 @@ The following releases are verified and recommended:
|
||||
|
||||
|Commit ID|Tag|Release|Verified Platform| Update date|
|
||||
|-|-|-|-|-|
|
||||
|24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |ArcB580/Linux/oneAPI 2025.1<br>LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15|
|
||||
|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19|
|
||||
|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1||
|
||||
|24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |Arc B580/Linux/oneAPI 2025.1<br>LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15|
|
||||
|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc A770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19|
|
||||
|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc A770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1||
|
||||
|
||||
|
||||
## News
|
||||
@@ -51,7 +51,7 @@ The following releases are verified and recommended:
|
||||
|-|-|-|-|
|
||||
|PVC 1550|39|73|+87%|
|
||||
|Flex 170|39|50|+28%|
|
||||
|Arc770|42|55|+30%|
|
||||
|Arc A770|42|55|+30%|
|
||||
|MTL|13|16|+23%|
|
||||
|ARL-H|14|17|+21%|
|
||||
|
||||
@@ -62,7 +62,7 @@ The following releases are verified and recommended:
|
||||
- Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs.
|
||||
|
||||
- 2024.5
|
||||
- Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770.
|
||||
- Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc A770.
|
||||
- Arch Linux is verified successfully.
|
||||
|
||||
- 2024.4
|
||||
@@ -111,7 +111,8 @@ On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md) although the
|
||||
|-------------------------------|---------|---------------------------------------|
|
||||
| Intel Data Center Max Series | Support | Max 1550, 1100 |
|
||||
| Intel Data Center Flex Series | Support | Flex 170 |
|
||||
| Intel Arc Series | Support | Arc 770, 730M, Arc A750, B580 |
|
||||
| Intel Arc A-Series | Support | Arc A770, Arc A730M, Arc A750 |
|
||||
| Intel Arc B-Series | Support | Arc B580 |
|
||||
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake, Lunar Lake |
|
||||
| Intel iGPU | Support | iGPU in 13700k, 13400, i5-1250P, i7-1260P, i7-1165G7 |
|
||||
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
{
|
||||
"version": 4,
|
||||
"version": 5,
|
||||
"cmakeMinimumRequired": {
|
||||
"major": 3,
|
||||
"minor": 28,
|
||||
"patch": 0
|
||||
},
|
||||
"configurePresets": [
|
||||
{
|
||||
"name": "arm64-android-snapdragon",
|
||||
@@ -16,7 +21,9 @@
|
||||
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"CMAKE_PREFIX_PATH": "$env{OPENCL_SDK_ROOT}",
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"HEXAGON_TOOLS_ROOT": "$env{HEXAGON_TOOLS_ROOT}",
|
||||
"PREBUILT_LIB_DIR": "android_aarch64",
|
||||
"GGML_OPENMP": "OFF",
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
@@ -31,7 +38,15 @@
|
||||
"name": "arm64-windows-snapdragon",
|
||||
"inherits": [ "base", "arm64-windows-llvm" ],
|
||||
"cacheVariables": {
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"CMAKE_C_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -flto -D_GNU_SOURCE",
|
||||
"CMAKE_CXX_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -flto -D_GNU_SOURCE",
|
||||
"CMAKE_C_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_PREFIX_PATH": "$env{OPENCL_SDK_ROOT}",
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"HEXAGON_TOOLS_ROOT": "$env{HEXAGON_TOOLS_ROOT}",
|
||||
"PREBUILT_LIB_DIR": "windows_aarch64",
|
||||
"GGML_OPENMP": "OFF",
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
@@ -1,6 +1,8 @@
|
||||
# Snapdragon-based Android devices
|
||||
# Snapdragon-based devices
|
||||
|
||||
## How to Build
|
||||
## Setup
|
||||
|
||||
### Android
|
||||
|
||||
The easiest way to build llama.cpp for a Snapdragon-based Android device is using the toolchain Docker image (see github.com/snapdragon-toolchain).
|
||||
This image includes Android NDK, OpenCL SDK, Hexagon SDK, CMake, etc.
|
||||
@@ -12,7 +14,24 @@ This method works on Linux, macOS, and Windows. macOS and Windows users should i
|
||||
[d]/> cd /workspace
|
||||
```
|
||||
|
||||
The rest of the Android build process assumes that you're running inside the toolchain container.
|
||||
Note: The rest of the **Android** build process assumes that you're running inside the toolchain container.
|
||||
|
||||
### Windows On Snapdragon
|
||||
|
||||
Native Windows 11 arm64 builds has the following tools dependencies:
|
||||
- MS Visual Studio 2026 (Community Edition or Pro)
|
||||
- MSVC arm64 standard and runtime libraries
|
||||
- UCRT and Driver Kit
|
||||
- LLVM core libraries and Clang compiler (winget)
|
||||
- CMake, Git, Python (winget)
|
||||
- Hexagon SDK Community Edition 6.4 or later (see windows.md)
|
||||
- OpenCL SDK 2.3 or later (see windows.md)
|
||||
|
||||
Note: The rest of the **Windows** build process assumes that you're running natively in Powershell.
|
||||
Adapt below build commands accordingly.
|
||||
|
||||
## How to Build
|
||||
|
||||
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
|
||||
|
||||
```
|
||||
@@ -49,24 +68,26 @@ Preset CMake variables:
|
||||
To generate an installable "package" simply use cmake --install:
|
||||
|
||||
```
|
||||
[d]/workspace> cmake --install build-snapdragon --prefix pkg-adb/llama.cpp
|
||||
[d]/workspace> cmake --install build-snapdragon --prefix pkg-snapdragon/llama.cpp
|
||||
-- Install configuration: "Release"
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-cpu.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-opencl.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-hexagon.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v73.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v75.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v79.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v81.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-cpu.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-opencl.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-hexagon.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v73.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v75.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v79.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v81.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml.so
|
||||
...
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-bench
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-cli
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/bin/llama-bench
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/bin/llama-cli
|
||||
...
|
||||
```
|
||||
|
||||
## How to Install
|
||||
|
||||
### Android
|
||||
|
||||
For this step, your device needs to be configured for on-device development.
|
||||
Please see https://developer.android.com/studio/debug/dev-options for details.
|
||||
|
||||
@@ -74,10 +95,10 @@ Once ADB is enabled, use `adb push` to install `pkg-snapdragon` on the device.
|
||||
**Note that the toolchain Docker image doesn't have ADB and doesn't set up the ADB bridge. Please use native ADB on the host.**
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ adb push pkg-adb/llama.cpp /data/local/tmp/
|
||||
pkg-adb/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s)
|
||||
pkg-adb/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s)
|
||||
pkg-adb/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s)
|
||||
~/src/llama.cpp$ adb push pkg-snapdragon/llama.cpp /data/local/tmp/
|
||||
pkg-snapdragon/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s)
|
||||
pkg-snapdragon/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s)
|
||||
pkg-snapdragon/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s)
|
||||
102 files pushed, 0 skipped. 186.9 MB/s (963151597 bytes in 4.914s)
|
||||
```
|
||||
|
||||
@@ -92,6 +113,11 @@ At this point, you should also install some models:
|
||||
Llama-3.2-1B-Instruct-Q4_0.gguf: 1 file pushed, 0 skipped. 38.3 MB/s (773025920 bytes in 19.250s)
|
||||
```
|
||||
|
||||
### Windows
|
||||
|
||||
All artifacts are already installed in the `pkg-snapdragon` folder.
|
||||
To run, adapt below instructions to use Powershell scrits in `scripts/snapdragon/windows`.
|
||||
|
||||
## How to Run
|
||||
|
||||
The easiest way to run llama.cpp cli tools is using provided wrapper scripts that properly set up all required environment variables.
|
||||
161
docs/backend/snapdragon/windows.md
Normal file
161
docs/backend/snapdragon/windows.md
Normal file
@@ -0,0 +1,161 @@
|
||||
## Overview
|
||||
|
||||
The document covers procedures for installing the latest GPU and NPU drivers, and OpenCL and Hexagon SDKs.
|
||||
|
||||
|
||||
In order to use Hexagon NPU on Snapdragon Windows devices the underlying HTP Ops libraries (e.g libggml-htp-v73.so)
|
||||
must be included in the .cat file digitally signed with a trusted certificate.
|
||||
|
||||
This document covers details on how to generate personal certificate files (.pfx) and how to configure the system
|
||||
to allow for test signatures (aka test-signing).
|
||||
|
||||
## Install the latest Adreno OpenCL SDK
|
||||
|
||||
Either use the trimmed down version (optimized for CI) from
|
||||
|
||||
https://github.com/snapdragon-toolchain/opencl-sdk/releases/download/v2.3.2/adreno-opencl-sdk-v2.3.2-arm64-wos.tar.xz
|
||||
|
||||
Or download the complete official version from
|
||||
|
||||
https://softwarecenter.qualcomm.com/catalog/item/Adreno_OpenCL_SDK?version=2.3.2
|
||||
|
||||
Unzip/untar the archive into
|
||||
```
|
||||
c:\Qualcomm\OpenCL_SDK\2.3.2
|
||||
```
|
||||
|
||||
## Install the latest Hexagon SDK Community Edition
|
||||
|
||||
Either use the trimmed down version (optimized for CI) from
|
||||
|
||||
https://github.com/snapdragon-toolchain/hexagon-sdk/releases/download/v6.4.0.2/hexagon-sdk-v6.4.0.2-arm64-wos.tar.xz
|
||||
|
||||
Or download the complete official version from
|
||||
|
||||
https://softwarecenter.qualcomm.com/catalog/item/Hexagon_SDK?version=6.4.0.2
|
||||
|
||||
Unzip/untar the archive into
|
||||
```
|
||||
c:\Qualcomm\Hexagon_SDK\6.4.0.2
|
||||
```
|
||||
|
||||
## Install the latest Adreno GPU driver
|
||||
|
||||
Download the driver from
|
||||
|
||||
https://softwarecenter.qualcomm.com/catalog/item/Windows_Graphics_Driver
|
||||
|
||||
After the automated installation and reboot please make sure that the GPU device shows up in the `Device Manager` (under 'Display Adapters`)
|
||||
|
||||
## Install the latest Qualcomm NPU driver
|
||||
|
||||
Download the driver from
|
||||
|
||||
https://softwarecenter.qualcomm.com/catalog/item/Qualcomm_HND
|
||||
|
||||
After the automated installation and reboot please make sure that the Hexagon NPU device shows up in the `Device Manager` (under `Neural Processors`).
|
||||
|
||||
If the device is not available you can try installing all components (`qcnspmcdm8380`, `qcnspmcdm8380_ext`) manually.
|
||||
The components are extracted into
|
||||
```
|
||||
c:\QCDrivers\qcnspmcdm...
|
||||
```
|
||||
|
||||
## Enable NPU driver test signatures
|
||||
|
||||
Please note that the following steps are required only for the Hexagon NPU.
|
||||
Adreno GPU backend does not require test signatures.
|
||||
|
||||
### Enable testsigning
|
||||
|
||||
Use `bcdedit` to enable test-signing
|
||||
```
|
||||
> bcdedit /set TESTSIGNING ON
|
||||
```
|
||||
(Secure Boot may need to be disabled for this to work)
|
||||
|
||||
Make sure test-signing is enabled after reboot
|
||||
```
|
||||
> bcdedit /enum
|
||||
...
|
||||
testsigning Yes
|
||||
...
|
||||
```
|
||||
For additional details see Microsoft guide at
|
||||
|
||||
https://learn.microsoft.com/en-us/windows-hardware/drivers/install/the-testsigning-boot-configuration-option
|
||||
|
||||
### Create personal certificate
|
||||
|
||||
The tools required for this procedure are available as part of Windows SDK and Windows Driver Kit which should be
|
||||
installed as part of the MS Visual Studio.
|
||||
They are typically located at
|
||||
```
|
||||
c:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0
|
||||
```
|
||||
(replace 10.0.26100.0 with correct version).
|
||||
|
||||
To create personal self-signed certificate run the following commands (either from cmd or power-shell):
|
||||
```
|
||||
> cd c:\Users\MyUser
|
||||
> mkdir Certs
|
||||
> cd Certs
|
||||
> makecert -r -pe -ss PrivateCertStore -n CN=GGML.HTP.v1 -eku 1.3.6.1.5.5.7.3.3 -sv ggml-htp-v1.pvk ggml-htp-v1.cer
|
||||
> pvk2pfx.exe -pvk ggml-htp-v1.pvk -spc ggml-htp-v1.cer -pfx ggml-htp-v1.pfx
|
||||
```
|
||||
(replace `MyUser` with your username).
|
||||
|
||||
Add this certificate to `Trusted Root Certification Authorities` and `Trusted Publishers` stores.
|
||||
This can be done using `certlm` Certificate Manager tool.
|
||||
Right click on the certificate store, select `All Tasks -> Import` and follow the prompts to import the certificate from the
|
||||
PFX file you created above.
|
||||
|
||||
For additional details see Microsoft guide at
|
||||
|
||||
https://learn.microsoft.com/en-us/windows-hardware/drivers/install/introduction-to-test-signing
|
||||
|
||||
Make sure to save the PFX file, you will need it for the build procedures.
|
||||
Please note that the same certificate can be used for signing any number of builds.
|
||||
|
||||
## Build Hexagon backend with signed HTP ops libraries
|
||||
|
||||
The overall Hexagon backend build procedure for Windows on Snapdragon is the same as for other platforms.
|
||||
However, additional settings are required for generating and signing HTP Ops libraries.
|
||||
```
|
||||
> $env:OPENCL_SDK_ROOT="C:\Qualcomm\OpenCL_SDK\2.3.2"
|
||||
> $env:HEXAGON_SDK_ROOT="C:\Qualcomm\Hexagon_SDK\6.4.0.2"
|
||||
> $env:HEXAGON_TOOLS_ROOT="C:\Qualcomm\Hexagon_SDK\6.4.0.2\tools\HEXAGON_Tools\19.0.04"
|
||||
> $env:HEXAGON_HTP_CERT="c:\Users\MyUsers\Certs\ggml-htp-v1.pfx"
|
||||
> $env:WINDOWS_SDK_BIN="C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\arm64"
|
||||
|
||||
> cmake --preset arm64-windows-snapdragon -B build-wos
|
||||
...
|
||||
> cmake --install build-wos --prefix pkg-snapdragon
|
||||
```
|
||||
|
||||
Once the build is complete HTP ops libraries will be installed like this
|
||||
```
|
||||
> dir pkg-snapdragon/lib
|
||||
...
|
||||
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v73.so
|
||||
-a---- 1/22/2026 6:01 PM 191752 libggml-htp-v75.so
|
||||
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v79.so
|
||||
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v81.so
|
||||
-a---- 1/22/2026 6:01 PM 4139 libggml-htp.cat
|
||||
```
|
||||
|
||||
The .cat file, the signature and proper certicate installation can be verified with
|
||||
|
||||
```
|
||||
> signtool.exe verify /v /pa .\pkg-snapdragon\lib\libggml-htp.cat
|
||||
Verifying: .\pkg-snapdragon\lib\libggml-htp.cat
|
||||
|
||||
Signature Index: 0 (Primary Signature)
|
||||
Hash of file (sha256): 9820C664DA59D5EAE31DBB664127FCDAEF59CDC31502496BC567544EC2F401CF
|
||||
|
||||
Signing Certificate Chain:
|
||||
Issued to: GGML.HTP.v1
|
||||
...
|
||||
Successfully verified: .\pkg-snapdragon\lib\libggml-htp.cat
|
||||
...
|
||||
```
|
||||
@@ -144,7 +144,7 @@ We also have a [guide](./backend/CUDA-FEDORA.md) for setting up CUDA toolkit in
|
||||
- ***Necessary*** for users of [Atomic Desktops for Fedora](https://fedoraproject.org/atomic-desktops/); such as: [Silverblue](https://fedoraproject.org/atomic-desktops/silverblue/) and [Kinoite](https://fedoraproject.org/atomic-desktops/kinoite/).
|
||||
- (there are no supported CUDA packages for these systems)
|
||||
- ***Necessary*** for users that have a host that is not a: [Supported Nvidia CUDA Release Platform](https://developer.nvidia.com/cuda-downloads).
|
||||
- (for example, you may have [Fedora 42 Beta](https://fedoramagazine.org/announcing-fedora-linux-42-beta/) as your your host operating system)
|
||||
- (for example, you may have [Fedora 42 Beta](https://fedoramagazine.org/announcing-fedora-linux-42-beta/) as your host operating system)
|
||||
- ***Convenient*** For those running [Fedora Workstation](https://fedoraproject.org/workstation/) or [Fedora KDE Plasma Desktop](https://fedoraproject.org/spins/kde), and want to keep their host system clean.
|
||||
- *Optionally* toolbox packages are available: [Arch Linux](https://archlinux.org/), [Red Hat Enterprise Linux >= 8.5](https://www.redhat.com/en/technologies/linux-platforms/enterprise-linux), or [Ubuntu](https://ubuntu.com/download)
|
||||
|
||||
@@ -495,6 +495,37 @@ Finally, after finishing your build, you should be able to do something like thi
|
||||
# ggml_vulkan: Using Intel(R) Graphics (ADL GT2) | uma: 1 | fp16: 1 | warp size: 32
|
||||
```
|
||||
|
||||
### For Mac users:
|
||||
|
||||
Generally, follow LunarG's [Getting Started with the MacOS Vulkan SDK](https://vulkan.lunarg.com/doc/sdk/latest/mac/getting_started.html) guide for installation and setup of the Vulkan SDK. There are two options of Vulkan drivers on macOS, both of which implement translation layers to map Vulkan to Metal. They can be hot-swapped by setting the `VK_ICD_FILENAMES` environment variable to point to the respective ICD JSON file.
|
||||
|
||||
Check the box for "KosmicKrisp" during the LunarG Vulkan SDK installation.
|
||||
|
||||
Set environment variable for the LunarG Vulkan SDK after installation (and optionally add to your shell profile for persistence):
|
||||
```bash
|
||||
source /path/to/vulkan-sdk/setup-env.sh
|
||||
```
|
||||
|
||||
#### Using MoltenVK
|
||||
|
||||
MoltenVK is the default Vulkan driver installed with the LunarG Vulkan SDK on macOS, so you can use the above environment variable settings as is.
|
||||
|
||||
#### Using KosmicKrisp
|
||||
|
||||
Override the environment variable for KosmicKrisp:
|
||||
```bash
|
||||
export VK_ICD_FILENAMES=$VULKAN_SDK/share/vulkan/icd.d/libkosmickrisp_icd.json
|
||||
export VK_DRIVER_FILES=$VULKAN_SDK/share/vulkan/icd.d/libkosmickrisp_icd.json
|
||||
```
|
||||
|
||||
#### Build
|
||||
|
||||
This is the only step different from [above](#common-steps) instructions.
|
||||
```bash
|
||||
cmake -B build -DGGML_VULKAN=1 -DGGML_METAL=OFF
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
## CANN
|
||||
This provides NPU acceleration using the AI cores of your Ascend NPU. And [CANN](https://www.hiascend.com/en/software/cann) is a hierarchical APIs to help you to quickly build AI applications and service based on Ascend NPU.
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ Legend:
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
@@ -114,7 +114,7 @@ Legend:
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
|
||||
1368
docs/ops/SYCL.csv
1368
docs/ops/SYCL.csv
File diff suppressed because it is too large
Load Diff
120
docs/speculative.md
Normal file
120
docs/speculative.md
Normal file
@@ -0,0 +1,120 @@
|
||||
# Speculative Decoding
|
||||
|
||||
llama.cpp supports speculative decoding, a technique that can significantly accelerate token generation by predicting multiple tokens ahead of the main model.
|
||||
|
||||
[Speculative decoding](https://en.wikipedia.org/wiki/Transformer_(deep_learning)#Speculative_decoding) leverages the fact that computing n tokens in a batch (as in prompt processing) is more efficient than computing n sequentially (as in response generation). By generating draft tokens quickly and then verifying them with the target model in a single batch, this approach can achieve substantial speedups when the draft predictions are frequently correct.
|
||||
|
||||
## Implementations
|
||||
|
||||
The `llama-server` application supports several implementations of speculative decoding:
|
||||
|
||||
### Draft Model (`draft`)
|
||||
|
||||
A much smaller model (called the _draft model_) generates drafts.
|
||||
A draft model is the most used approach in speculative decoding.
|
||||
|
||||
### n-gram Cache (`ngram-cache`)
|
||||
|
||||
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
|
||||
A draft is computed using probabilities derived from these statistics. External statistics can also be loaded from files for improved accuracy.
|
||||
|
||||
See:
|
||||
|
||||
- #5479, #6828, #6848
|
||||
|
||||
### n-gram Map (`ngram-simple`, `ngram-map-*`)
|
||||
|
||||
These implementations search the token history for patterns and use matching sequences as draft candidates.
|
||||
They require no additional model but rely on patterns that have already appeared in the generated text.
|
||||
An example to use this approach can be the rewriting of source code by a LLM.
|
||||
|
||||
#### n-gram Map (`ngram-simple`)
|
||||
|
||||
This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead.
|
||||
|
||||
#### n-gram Map Key (`ngram-map-k`)
|
||||
|
||||
This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-min-hits`) before generating drafts.
|
||||
|
||||
The number of accepted tokens is stored for each used n-gram.
|
||||
|
||||
#### n-gram Map Key-4-Values (`ngram-map-k4v`)
|
||||
|
||||
This experimental implementation looks for the current n-gram of size n (called the _key_) in the token history. For each key, up to four _values_ (n-grams of size m, called _mgrams_) are tracked. An internal statistic counts the occurrences of each mgram after the key n-gram. If one mgram is significantly more frequent than the others, it is used as the draft.
|
||||
|
||||
The number of accepted tokens is stored for each used n-gram.
|
||||
|
||||
**Example:** Server options to be used if there are a lot of longer repetitions.
|
||||
```bash
|
||||
llama-server [...] --spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2
|
||||
```
|
||||
|
||||
|
||||
## Command-Line Options
|
||||
|
||||
If a draft model is combined with a draftless decoding the draftless decoding has higher precedence.
|
||||
|
||||
```
|
||||
--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]
|
||||
type of speculative decoding to use when no draft model is provided
|
||||
(default: none)
|
||||
--spec-ngram-size-n N ngram size N for ngram-simple/ngram-map speculative decoding, length
|
||||
of lookup n-gram (default: 12)
|
||||
--spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length
|
||||
of draft m-gram (default: 48)
|
||||
--spec-ngram-check-rate N ngram check rate for ngram-simple/ngram-map speculative decoding
|
||||
(default: 1)
|
||||
--spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1)
|
||||
```
|
||||
|
||||
### `--spec-type TYPE`
|
||||
|
||||
Specifies a type of speculative decoding without draft model.
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `none` | No speculative decoding (default) |
|
||||
| `ngram-cache` | Use n-gram cache lookup |
|
||||
| `ngram-simple` | Use simple n-gram pattern matching |
|
||||
| `ngram-map-k` | Use n-gram pattern matching with n-gram-keys |
|
||||
| `ngram-map-k4v` | Use n-gram pattern matching with n-gram-keys and up to four m-gram values (experimental) |
|
||||
|
||||
**Example:** Server-instance used to refactor source code.
|
||||
```bash
|
||||
./llama-server [...] --spec-type ngram-simple
|
||||
```
|
||||
|
||||
### `--spec-ngram-size-n N`
|
||||
|
||||
Sets the size N of the lookup n-gram for n-gram map based speculative decoding.
|
||||
The n-gram size N determines how many tokens in a row to look back when searching for matching patterns.
|
||||
|
||||
### `--spec-ngram-size-m M`
|
||||
|
||||
Sets the size M of the draft m-gram for n-gram map based speculative decoding.
|
||||
The m-gram size determines how many tokens to draft when a match is found.
|
||||
Larger values can provide more speedup but may reduce acceptance rate.
|
||||
|
||||
### `--spec-ngram-check-rate R`
|
||||
|
||||
This option aims at performance if the n-gram lookup in history is to costly. A lookup will be executed at every R tokens (default is 1, every token).
|
||||
|
||||
### `--spec-ngram-min-hits H`
|
||||
|
||||
This option defines how often a key has to appear in the token history to be used as a draft (default is 1).
|
||||
|
||||
## Statistics
|
||||
Each speculative decoding implementation prints statistics.
|
||||
|
||||
```
|
||||
draft acceptance rate = 0.57576 ( 171 accepted / 297 generated)
|
||||
statistics ngram_simple: #calls = 15, #gen drafts = 5, #acc drafts = 5, #gen tokens = 187, #acc tokens = 73
|
||||
statistics draft: #calls = 10, #gen drafts = 10, #acc drafts = 10, #gen tokens = 110, #acc tokens = 98
|
||||
```
|
||||
|
||||
- `#calls`: number of calls of this implementations
|
||||
- `#gen drafts`: number of drafts generated by this implementation
|
||||
- `#acc drafts`: number of drafts accepted (partially) by the main model
|
||||
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
|
||||
- `#acc tokens`: number of tokens accepted by the main model
|
||||
|
||||
2
examples/compare-mlx/.gitignore
vendored
Normal file
2
examples/compare-mlx/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*.txt
|
||||
*/
|
||||
3
examples/compare-mlx/cmd.sh
Executable file
3
examples/compare-mlx/cmd.sh
Executable file
@@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
./compare-mlx.sh --raw-path ../../build/wikitext-2-raw/wiki.test.raw --no-keep
|
||||
707
examples/compare-mlx/compare-mlx.sh
Executable file
707
examples/compare-mlx/compare-mlx.sh
Executable file
@@ -0,0 +1,707 @@
|
||||
#!/bin/bash
|
||||
|
||||
# a script to compare MLX and GGUF models
|
||||
#
|
||||
# usage:
|
||||
# ./compare-mlx.sh --raw-path wiki.test.raw --no-keep
|
||||
#
|
||||
# TODOs
|
||||
# - add QAT evals
|
||||
|
||||
# check if LLAMA_HOME_DIR is set
|
||||
if [[ -z "$LLAMA_HOME_DIR" ]]; then
|
||||
lcpp_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")"/../../ && pwd)
|
||||
else
|
||||
lcpp_dir="${LLAMA_HOME_DIR}"
|
||||
fi
|
||||
|
||||
echo "Using llama.cpp directory: ${lcpp_dir}"
|
||||
|
||||
# check for convert_hf_to_gguf.py
|
||||
if [[ ! -f "${lcpp_dir}/convert_hf_to_gguf.py" ]]; then
|
||||
echo "convert_hf_to_gguf.py not found in ${lcpp_dir}"
|
||||
echo "Set a LLAMA_HOME_DIR environment variable to point to your llama.cpp directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
set -x
|
||||
|
||||
# sanity checks that all Python dependencies are installed
|
||||
if ! python -c "import mlx.core"; then
|
||||
echo "MLX not found. Please install MLX"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! python ${lcpp_dir}/convert_hf_to_gguf.py --help; then
|
||||
echo "convert_hf_to_gguf.py not working. Please install llama.cpp python requirements"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# by default use the system binaries (for example from brew)
|
||||
llama_perplexity="llama-perplexity"
|
||||
|
||||
if [[ ! -z "$LLAMA_PERPLEXITY" ]]; then
|
||||
llama_perplexity="$LLAMA_PERPLEXITY"
|
||||
fi
|
||||
|
||||
echo "Using llama-perplexity: ${llama_perplexity}"
|
||||
|
||||
if ! command -v "$llama_perplexity" &> /dev/null; then
|
||||
echo "llama-perplexity not found. Please install it."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
llama_quantize="llama-quantize"
|
||||
|
||||
if [[ ! -z "$LLAMA_QUANTIZE" ]]; then
|
||||
llama_quantize="$LLAMA_QUANTIZE"
|
||||
fi
|
||||
|
||||
echo "Using llama-quantize: ${llama_quantize}"
|
||||
|
||||
if ! command -v "$llama_quantize" &> /dev/null; then
|
||||
echo "llama-quantize not found. Please install it."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
llama_batched_bench="llama-batched-bench"
|
||||
|
||||
if [[ ! -z "$LLAMA_BATCHED_BENCH" ]]; then
|
||||
llama_batched_bench="$LLAMA_BATCHED_BENCH"
|
||||
fi
|
||||
|
||||
echo "Using llama-batched-bench: ${llama_batched_bench}"
|
||||
|
||||
if ! command -v "$llama_batched_bench" &> /dev/null; then
|
||||
echo "llama-batched-bench not found. Please install it."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# get the size in GiB
|
||||
get_size() {
|
||||
local path="$1"
|
||||
local bytes=$(du -s "$path" | awk '{print $1}')
|
||||
local res=$(echo "scale=3; ($bytes*512)/1024/1024/1024" | bc)
|
||||
echo "$res"
|
||||
}
|
||||
|
||||
# parameters:
|
||||
# --no-compute : do not compute anything, just summarize the existing results
|
||||
# --no-ppl : do not compute ppl
|
||||
# --no-perf : do not compute performance (speed) metrics
|
||||
# --no-keep : delete intermediate model files
|
||||
# --num-samples : number of text samples to evaluate (default: 512)
|
||||
# --sequence-length : sequence length of the samples in tokens (default: 512)
|
||||
# --raw-path : file with raw text (such as wikitext)
|
||||
|
||||
# extra agruments
|
||||
args_lcpp="-t 1"
|
||||
|
||||
num_samples=512
|
||||
sequence_length=512
|
||||
raw_path=""
|
||||
no_compute=false
|
||||
no_ppl=false
|
||||
no_perf=false
|
||||
no_keep=false
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--no-compute)
|
||||
no_compute=true
|
||||
shift
|
||||
;;
|
||||
--no-ppl)
|
||||
no_ppl=true
|
||||
shift
|
||||
;;
|
||||
--no-perf)
|
||||
no_perf=true
|
||||
shift
|
||||
;;
|
||||
--no-keep)
|
||||
no_keep=true
|
||||
shift
|
||||
;;
|
||||
--num-samples)
|
||||
num_samples="$2"
|
||||
shift 2
|
||||
;;
|
||||
--sequence-length)
|
||||
sequence_length="$2"
|
||||
shift 2
|
||||
;;
|
||||
--raw-path)
|
||||
raw_path="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown parameter: $1"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ -z "$raw_path" ]]; then
|
||||
echo "No raw path provided"
|
||||
echo "Recommended to use the test set of WikiText from here: https://github.com/ggml-org/llama.cpp/blob/master/scripts/get-wikitext-2.sh"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
eval_model() {
|
||||
org="$1"
|
||||
mid="$2"
|
||||
|
||||
echo "Evaluating ${org}/${mid}"
|
||||
|
||||
huggingface-cli download ${org}/${mid} --local-dir ${org}/${mid}
|
||||
|
||||
# generate and process MLX models
|
||||
|
||||
if [[ "$no_compute" == true ]]; then
|
||||
echo "Skipping computation"
|
||||
else
|
||||
rm -rfv ./${mid}-f32-mlx
|
||||
mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-f32-mlx --dtype float32
|
||||
get_size ./${mid}-f32-mlx > ./${mid}-f32-mlx-size.txt
|
||||
|
||||
if [[ "$no_ppl" == false ]]; then
|
||||
python ./mlx-ppl.py --model ./${mid}-f32-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-f32-mlx-ppl.txt
|
||||
fi
|
||||
|
||||
# no need for F32 perf benchmarks
|
||||
#if [[ "$no_perf" == false ]]; then
|
||||
# mlx_lm.benchmark --model ./${mid}-f32-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f32-mlx-perf-2048.txt
|
||||
# mlx_lm.benchmark --model ./${mid}-f32-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f32-mlx-perf-4096.txt
|
||||
# mlx_lm.benchmark --model ./${mid}-f32-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f32-mlx-perf-8192.txt
|
||||
# mlx_lm.benchmark --model ./${mid}-f32-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f32-mlx-perf-16384.txt
|
||||
# mlx_lm.benchmark --model ./${mid}-f32-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f32-mlx-perf-32768.txt
|
||||
#fi
|
||||
|
||||
if [[ "$no_keep" == true ]]; then
|
||||
echo "Deleting intermediate model files"
|
||||
rm -rfv ./${mid}-f32-mlx
|
||||
fi
|
||||
|
||||
rm -rfv ./${mid}-bf16-mlx
|
||||
mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-bf16-mlx --dtype bfloat16
|
||||
get_size ./${mid}-bf16-mlx > ./${mid}-bf16-mlx-size.txt
|
||||
|
||||
if [[ "$no_ppl" == false ]]; then
|
||||
python ./mlx-ppl.py --model ./${mid}-bf16-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-bf16-mlx-ppl.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_perf" == false ]]; then
|
||||
mlx_lm.benchmark --model ./${mid}-bf16-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-bf16-mlx-perf-2048.txt
|
||||
mlx_lm.benchmark --model ./${mid}-bf16-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-bf16-mlx-perf-4096.txt
|
||||
mlx_lm.benchmark --model ./${mid}-bf16-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-bf16-mlx-perf-8192.txt
|
||||
mlx_lm.benchmark --model ./${mid}-bf16-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-bf16-mlx-perf-16384.txt
|
||||
mlx_lm.benchmark --model ./${mid}-bf16-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-bf16-mlx-perf-32768.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_keep" == true ]]; then
|
||||
echo "Deleting intermediate model files"
|
||||
rm -rfv ./${mid}-bf16-mlx
|
||||
fi
|
||||
|
||||
rm -rfv ./${mid}-f16-mlx
|
||||
mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-f16-mlx --dtype float16
|
||||
get_size ./${mid}-f16-mlx > ./${mid}-f16-mlx-size.txt
|
||||
|
||||
if [[ "$no_ppl" == false ]]; then
|
||||
python ./mlx-ppl.py --model ./${mid}-f16-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-f16-mlx-ppl.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_perf" == false ]]; then
|
||||
mlx_lm.benchmark --model ./${mid}-f16-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f16-mlx-perf-2048.txt
|
||||
mlx_lm.benchmark --model ./${mid}-f16-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f16-mlx-perf-4096.txt
|
||||
mlx_lm.benchmark --model ./${mid}-f16-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f16-mlx-perf-8192.txt
|
||||
mlx_lm.benchmark --model ./${mid}-f16-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f16-mlx-perf-16384.txt
|
||||
mlx_lm.benchmark --model ./${mid}-f16-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f16-mlx-perf-32768.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_keep" == true ]]; then
|
||||
echo "Deleting intermediate model files"
|
||||
rm -rfv ./${mid}-f16-mlx
|
||||
fi
|
||||
|
||||
rm -rfv ./${mid}-q8-mlx
|
||||
mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-q8-mlx --quantize --q-bits 8 --dtype float16
|
||||
get_size ./${mid}-q8-mlx > ./${mid}-q8-mlx-size.txt
|
||||
|
||||
if [[ "$no_ppl" == false ]]; then
|
||||
python ./mlx-ppl.py --model ./${mid}-q8-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-q8-mlx-ppl.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_perf" == false ]]; then
|
||||
mlx_lm.benchmark --model ./${mid}-q8-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q8-mlx-perf-2048.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q8-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q8-mlx-perf-4096.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q8-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q8-mlx-perf-8192.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q8-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q8-mlx-perf-16384.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q8-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q8-mlx-perf-32768.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_keep" == true ]]; then
|
||||
echo "Deleting intermediate model files"
|
||||
rm -rfv ./${mid}-q8-mlx
|
||||
fi
|
||||
|
||||
#rm -rfv ./${mid}-q6-mlx
|
||||
#mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-q6-mlx --quantize --q-bits 6 --dtype float16
|
||||
#get_size ./${mid}-q6-mlx > ./${mid}-q6-mlx-size.txt
|
||||
|
||||
#if [[ "$no_ppl" == false ]]; then
|
||||
# python ./mlx-ppl.py --model ./${mid}-q6-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-q6-mlx-ppl.txt
|
||||
#fi
|
||||
|
||||
#if [[ "$no_perf" == false ]]; then
|
||||
# mlx_lm.benchmark --model ./${mid}-q6-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q6-mlx-perf-2048.txt
|
||||
# mlx_lm.benchmark --model ./${mid}-q6-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q6-mlx-perf-4096.txt
|
||||
# mlx_lm.benchmark --model ./${mid}-q6-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q6-mlx-perf-8192.txt
|
||||
# mlx_lm.benchmark --model ./${mid}-q6-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q6-mlx-perf-16384.txt
|
||||
# mlx_lm.benchmark --model ./${mid}-q6-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q6-mlx-perf-32768.txt
|
||||
#fi
|
||||
|
||||
#if [[ "$no_keep" == true ]]; then
|
||||
# echo "Deleting intermediate model files"
|
||||
# rm -rfv ./${mid}-q6-mlx
|
||||
#fi
|
||||
|
||||
#rm -rfv ./${mid}-q5-mlx
|
||||
#mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-q5-mlx --quantize --q-bits 5 --dtype float16
|
||||
#get_size ./${mid}-q5-mlx > ./${mid}-q5-mlx-size.txt
|
||||
|
||||
#if [[ "$no_ppl" == false ]]; then
|
||||
# python ./mlx-ppl.py --model ./${mid}-q5-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-q5-mlx-ppl.txt
|
||||
#fi
|
||||
|
||||
#if [[ "$no_perf" == false ]]; then
|
||||
# mlx_lm.benchmark --model ./${mid}-q5-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q5-mlx-perf-2048.txt
|
||||
# mlx_lm.benchmark --model ./${mid}-q5-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q5-mlx-perf-4096.txt
|
||||
# mlx_lm.benchmark --model ./${mid}-q5-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q5-mlx-perf-8192.txt
|
||||
# mlx_lm.benchmark --model ./${mid}-q5-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q5-mlx-perf-16384.txt
|
||||
# mlx_lm.benchmark --model ./${mid}-q5-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q5-mlx-perf-32768.txt
|
||||
#fi
|
||||
|
||||
#if [[ "$no_keep" == true ]]; then
|
||||
# echo "Deleting intermediate model files"
|
||||
# rm -rfv ./${mid}-q5-mlx
|
||||
#fi
|
||||
|
||||
# I think this is something similar to q4_k
|
||||
rm -rfv ./${mid}-q4p-mlx
|
||||
mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-q4p-mlx --quantize --quant-predicate mixed_4_6 --dtype float16
|
||||
get_size ./${mid}-q4p-mlx > ./${mid}-q4p-mlx-size.txt
|
||||
|
||||
if [[ "$no_ppl" == false ]]; then
|
||||
python ./mlx-ppl.py --model ./${mid}-q4p-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-q4p-mlx-ppl.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_perf" == false ]]; then
|
||||
mlx_lm.benchmark --model ./${mid}-q4p-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4p-mlx-perf-2048.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q4p-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4p-mlx-perf-4096.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q4p-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4p-mlx-perf-8192.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q4p-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4p-mlx-perf-16384.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q4p-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4p-mlx-perf-32768.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_keep" == true ]]; then
|
||||
echo "Deleting intermediate model files"
|
||||
rm -rfv ./${mid}-q4p-mlx
|
||||
fi
|
||||
|
||||
rm -rfv ./${mid}-q4-mlx
|
||||
mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-q4-mlx --quantize --q-bits 4 --dtype float16
|
||||
get_size ./${mid}-q4-mlx > ./${mid}-q4-mlx-size.txt
|
||||
|
||||
if [[ "$no_ppl" == false ]]; then
|
||||
python ./mlx-ppl.py --model ./${mid}-q4-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-q4-mlx-ppl.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_perf" == false ]]; then
|
||||
mlx_lm.benchmark --model ./${mid}-q4-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4-mlx-perf-2048.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q4-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4-mlx-perf-4096.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q4-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4-mlx-perf-8192.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q4-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4-mlx-perf-16384.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q4-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4-mlx-perf-32768.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_keep" == true ]]; then
|
||||
echo "Deleting intermediate model files"
|
||||
rm -rfv ./${mid}-q4-mlx
|
||||
fi
|
||||
|
||||
rm -rfv ./${mid}-q3-mlx
|
||||
mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-q3-mlx --quantize --q-bits 3 --dtype float16
|
||||
get_size ./${mid}-q3-mlx > ./${mid}-q3-mlx-size.txt
|
||||
|
||||
if [[ "$no_ppl" == false ]]; then
|
||||
python ./mlx-ppl.py --model ./${mid}-q3-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-q3-mlx-ppl.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_perf" == false ]]; then
|
||||
mlx_lm.benchmark --model ./${mid}-q3-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q3-mlx-perf-2048.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q3-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q3-mlx-perf-4096.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q3-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q3-mlx-perf-8192.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q3-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q3-mlx-perf-16384.txt
|
||||
mlx_lm.benchmark --model ./${mid}-q3-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q3-mlx-perf-32768.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_keep" == true ]]; then
|
||||
echo "Deleting intermediate model files"
|
||||
rm -rfv ./${mid}-q3-mlx
|
||||
fi
|
||||
fi
|
||||
|
||||
# generate and process llama.cpp GGUF models
|
||||
|
||||
if [[ "$no_compute" == true ]]; then
|
||||
echo "Skipping computation"
|
||||
else
|
||||
# the F32 model is the reference - we generate all other models from it
|
||||
mkdir -p ./${mid}-f32-gguf
|
||||
python ${lcpp_dir}/convert_hf_to_gguf.py ./${org}/${mid} --outtype f32 --outfile ./${mid}-f32-gguf/model.gguf
|
||||
get_size ./${mid}-f32-gguf > ./${mid}-f32-gguf-size.txt
|
||||
|
||||
if [[ "$no_ppl" == false ]]; then
|
||||
${llama_perplexity} $args_lcpp -m ./${mid}-f32-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-f32-gguf-ppl.txt
|
||||
fi
|
||||
|
||||
# no need for F32 perf benchmarks
|
||||
#if [[ "$no_perf" == false ]]; then
|
||||
# ${llama_batched_bench} $args_lcpp -m ./${mid}-f32-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-f32-gguf-perf.txt
|
||||
#fi
|
||||
|
||||
# this requires to explicitly build llama.cpp with BF16 support
|
||||
rm -rfv ./${mid}-bf16-gguf && mkdir -p ./${mid}-bf16-gguf
|
||||
${llama_quantize} ./${mid}-f32-gguf/model.gguf ./${mid}-bf16-gguf/model.gguf bf16
|
||||
get_size ./${mid}-bf16-gguf > ./${mid}-bf16-gguf-size.txt
|
||||
|
||||
if [[ "$no_ppl" == false ]]; then
|
||||
${llama_perplexity} $args_lcpp -m ./${mid}-bf16-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-bf16-gguf-ppl.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_perf" == false ]]; then
|
||||
${llama_batched_bench} $args_lcpp -m ./${mid}-bf16-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-bf16-gguf-perf.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_keep" == true ]]; then
|
||||
echo "Deleting intermediate model files"
|
||||
rm -rfv ./${mid}-bf16-gguf
|
||||
fi
|
||||
|
||||
rm -rfv ./${mid}-f16-gguf && mkdir -p ./${mid}-f16-gguf
|
||||
${llama_quantize} ./${mid}-f32-gguf/model.gguf ./${mid}-f16-gguf/model.gguf f16
|
||||
get_size ./${mid}-f16-gguf > ./${mid}-f16-gguf-size.txt
|
||||
|
||||
if [[ "$no_ppl" == false ]]; then
|
||||
${llama_perplexity} $args_lcpp -m ./${mid}-f16-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-f16-gguf-ppl.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_perf" == false ]]; then
|
||||
${llama_batched_bench} $args_lcpp -m ./${mid}-f16-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-f16-gguf-perf.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_keep" == true ]]; then
|
||||
echo "Deleting intermediate model files"
|
||||
rm -rfv ./${mid}-f16-gguf
|
||||
fi
|
||||
|
||||
rm -rfv ./${mid}-q8-gguf && mkdir -p ./${mid}-q8-gguf
|
||||
${llama_quantize} ./${mid}-f32-gguf/model.gguf ./${mid}-q8-gguf/model.gguf q8_0
|
||||
get_size ./${mid}-q8-gguf > ./${mid}-q8-gguf-size.txt
|
||||
|
||||
if [[ "$no_ppl" == false ]]; then
|
||||
${llama_perplexity} $args_lcpp -m ./${mid}-q8-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-q8-gguf-ppl.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_perf" == false ]]; then
|
||||
${llama_batched_bench} $args_lcpp -m ./${mid}-q8-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-q8-gguf-perf.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_keep" == true ]]; then
|
||||
echo "Deleting intermediate model files"
|
||||
rm -rfv ./${mid}-q8-gguf
|
||||
fi
|
||||
|
||||
#rm -rfv ./${mid}-q6-gguf && mkdir -p ./${mid}-q6-gguf
|
||||
#${llama_quantize} ./${mid}-f32-gguf/model.gguf ./${mid}-q6-gguf/model.gguf q6_k
|
||||
#get_size ./${mid}-q6-gguf > ./${mid}-q6-gguf-size.txt
|
||||
|
||||
#if [[ "$no_ppl" == false ]]; then
|
||||
# ${llama_perplexity} $args_lcpp -m ./${mid}-q6-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-q6-gguf-ppl.txt
|
||||
#fi
|
||||
|
||||
#if [[ "$no_perf" == false ]]; then
|
||||
# ${llama_batched_bench} $args_lcpp -m ./${mid}-q6-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-q6-gguf-perf.txt
|
||||
#fi
|
||||
|
||||
#if [[ "$no_keep" == true ]]; then
|
||||
# echo "Deleting intermediate model files"
|
||||
# rm -rfv ./${mid}-q6-gguf
|
||||
#fi
|
||||
|
||||
#rm -rfv ./${mid}-q5-gguf && mkdir -p ./${mid}-q5-gguf
|
||||
#${llama_quantize} ./${mid}-f32-gguf/model.gguf ./${mid}-q5-gguf/model.gguf q5_k_s
|
||||
#get_size ./${mid}-q5-gguf > ./${mid}-q5-gguf-size.txt
|
||||
|
||||
#if [[ "$no_ppl" == false ]]; then
|
||||
# ${llama_perplexity} $args_lcpp -m ./${mid}-q5-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-q5-gguf-ppl.txt
|
||||
#fi
|
||||
|
||||
#if [[ "$no_perf" == false ]]; then
|
||||
# ${llama_batched_bench} $args_lcpp -m ./${mid}-q5-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-q5-gguf-perf.txt
|
||||
#fi
|
||||
|
||||
#if [[ "$no_keep" == true ]]; then
|
||||
# echo "Deleting intermediate model files"
|
||||
# rm -rfv ./${mid}-q5-gguf
|
||||
#fi
|
||||
|
||||
rm -rfv ./${mid}-q4p-gguf && mkdir -p ./${mid}-q4p-gguf
|
||||
${llama_quantize} ./${mid}-f32-gguf/model.gguf ./${mid}-q4p-gguf/model.gguf q4_k
|
||||
get_size ./${mid}-q4p-gguf > ./${mid}-q4p-gguf-size.txt
|
||||
|
||||
if [[ "$no_ppl" == false ]]; then
|
||||
${llama_perplexity} $args_lcpp -m ./${mid}-q4p-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-q4p-gguf-ppl.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_perf" == false ]]; then
|
||||
${llama_batched_bench} $args_lcpp -m ./${mid}-q4p-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-q4p-gguf-perf.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_keep" == true ]]; then
|
||||
echo "Deleting intermediate model files"
|
||||
rm -rfv ./${mid}-q4p-gguf
|
||||
fi
|
||||
|
||||
# note: we use --pure here to match the MLX quantization of the embeddings
|
||||
rm -rfv ./${mid}-q4-gguf && mkdir -p ./${mid}-q4-gguf
|
||||
${llama_quantize} --pure ./${mid}-f32-gguf/model.gguf ./${mid}-q4-gguf/model.gguf q4_0
|
||||
get_size ./${mid}-q4-gguf > ./${mid}-q4-gguf-size.txt
|
||||
|
||||
if [[ "$no_ppl" == false ]]; then
|
||||
${llama_perplexity} $args_lcpp -m ./${mid}-q4-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-q4-gguf-ppl.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_perf" == false ]]; then
|
||||
${llama_batched_bench} $args_lcpp -m ./${mid}-q4-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-q4-gguf-perf.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_keep" == true ]]; then
|
||||
echo "Deleting intermediate model files"
|
||||
rm -rfv ./${mid}-q4-gguf
|
||||
fi
|
||||
|
||||
rm -rfv ./${mid}-q3-gguf && mkdir -p ./${mid}-q3-gguf
|
||||
${llama_quantize} ./${mid}-f32-gguf/model.gguf ./${mid}-q3-gguf/model.gguf q3_k_s
|
||||
get_size ./${mid}-q3-gguf > ./${mid}-q3-gguf-size.txt
|
||||
|
||||
if [[ "$no_ppl" == false ]]; then
|
||||
${llama_perplexity} $args_lcpp -m ./${mid}-q3-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-q3-gguf-ppl.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_perf" == false ]]; then
|
||||
${llama_batched_bench} $args_lcpp -m ./${mid}-q3-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-q3-gguf-perf.txt
|
||||
fi
|
||||
|
||||
if [[ "$no_keep" == true ]]; then
|
||||
echo "Deleting intermediate model files"
|
||||
rm -rfv ./${mid}-q3-gguf
|
||||
fi
|
||||
|
||||
# remove the f32 model at the end
|
||||
if [[ "$no_keep" == true ]]; then
|
||||
rm -rfv ./${mid}-f32-gguf
|
||||
fi
|
||||
fi
|
||||
|
||||
set +x
|
||||
|
||||
# analyze results
|
||||
|
||||
#types=("f32" "bf16" "f16" "q8" "q6" "q5" "q4p" "q4" "q3")
|
||||
types=("f32" "bf16" "f16" "q8" "q4p" "q4" "q3")
|
||||
|
||||
mlx_ppls=()
|
||||
mlx_ppl_deltas=()
|
||||
mlx_sizes=()
|
||||
mlx_pps2k=()
|
||||
mlx_tgs2k=()
|
||||
mlx_pps4k=()
|
||||
mlx_tgs4k=()
|
||||
mlx_pps8k=()
|
||||
mlx_tgs8k=()
|
||||
mlx_pps16k=()
|
||||
mlx_tgs16k=()
|
||||
mlx_pps32k=()
|
||||
mlx_tgs32k=()
|
||||
|
||||
# mlx:
|
||||
for t in ${types[*]}; do
|
||||
cur_ppl="N/A"
|
||||
cur_ppl_delta="N/A"
|
||||
cur_size="N/A"
|
||||
cur_pp2k="N/A"
|
||||
cur_tg2k="N/A"
|
||||
cur_pp4k="N/A"
|
||||
cur_tg4k="N/A"
|
||||
cur_pp8k="N/A"
|
||||
cur_tg8k="N/A"
|
||||
cur_pp16k="N/A"
|
||||
cur_tg16k="N/A"
|
||||
cur_pp32k="N/A"
|
||||
cur_tg32k="N/A"
|
||||
|
||||
if [[ -f ./${mid}-${t}-mlx-ppl.txt ]]; then
|
||||
cur_ppl=$(grep -o 'Perplexity: [0-9.]*' ./${mid}-${t}-mlx-ppl.txt | cut -d' ' -f2)
|
||||
cur_ppl_delta=$(grep -o 'Perplexity: [0-9.]* ± [0-9.]*' ./${mid}-${t}-mlx-ppl.txt | cut -d' ' -f4)
|
||||
cur_size=$(cat ./${mid}-${t}-mlx-size.txt)
|
||||
cur_pp2k=$(grep -o 'Averages.*prompt_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-2048.txt | cut -d'=' -f2)
|
||||
cur_tg2k=$(grep -o 'Averages.*generation_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-2048.txt | cut -d'=' -f3)
|
||||
cur_pp4k=$(grep -o 'Averages.*prompt_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-4096.txt | cut -d'=' -f2)
|
||||
cur_tg4k=$(grep -o 'Averages.*generation_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-4096.txt | cut -d'=' -f3)
|
||||
cur_pp8k=$(grep -o 'Averages.*prompt_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-8192.txt | cut -d'=' -f2)
|
||||
cur_tg8k=$(grep -o 'Averages.*generation_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-8192.txt | cut -d'=' -f3)
|
||||
cur_pp16k=$(grep -o 'Averages.*prompt_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-16384.txt | cut -d'=' -f2)
|
||||
cur_tg16k=$(grep -o 'Averages.*generation_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-16384.txt | cut -d'=' -f3)
|
||||
cur_pp32k=$(grep -o 'Averages.*prompt_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-32768.txt | cut -d'=' -f2)
|
||||
cur_tg32k=$(grep -o 'Averages.*generation_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-32768.txt | cut -d'=' -f3)
|
||||
fi
|
||||
|
||||
mlx_ppls+=("${cur_ppl}")
|
||||
mlx_ppl_deltas+=("${cur_ppl_delta}")
|
||||
mlx_sizes+=("${cur_size}")
|
||||
mlx_pps2k+=("${cur_pp2k}")
|
||||
mlx_tgs2k+=("${cur_tg2k}")
|
||||
mlx_pps4k+=("${cur_pp4k}")
|
||||
mlx_tgs4k+=("${cur_tg4k}")
|
||||
mlx_pps8k+=("${cur_pp8k}")
|
||||
mlx_tgs8k+=("${cur_tg8k}")
|
||||
mlx_pps16k+=("${cur_pp16k}")
|
||||
mlx_tgs16k+=("${cur_tg16k}")
|
||||
mlx_pps32k+=("${cur_pp32k}")
|
||||
mlx_tgs32k+=("${cur_tg32k}")
|
||||
done
|
||||
|
||||
gguf_ppls=()
|
||||
gguf_ppl_deltas=()
|
||||
gguf_sizes=()
|
||||
gguf_pps2k=()
|
||||
gguf_tgs2k=()
|
||||
gguf_pps4k=()
|
||||
gguf_tgs4k=()
|
||||
gguf_pps8k=()
|
||||
gguf_tgs8k=()
|
||||
gguf_pps16k=()
|
||||
gguf_tgs16k=()
|
||||
gguf_pps32k=()
|
||||
gguf_tgs32k=()
|
||||
|
||||
# gguf:
|
||||
for t in ${types[*]}; do
|
||||
cur_ppl="N/A"
|
||||
cur_ppl_delta="N/A"
|
||||
cur_size="N/A"
|
||||
cur_pp2k="N/A"
|
||||
cur_tg2k="N/A"
|
||||
cur_pp4k="N/A"
|
||||
cur_tg4k="N/A"
|
||||
cur_pp8k="N/A"
|
||||
cur_tg8k="N/A"
|
||||
cur_pp16k="N/A"
|
||||
cur_tg16k="N/A"
|
||||
cur_pp32k="N/A"
|
||||
cur_tg32k="N/A"
|
||||
|
||||
if [[ -f ./${mid}-${t}-gguf-ppl.txt ]]; then
|
||||
cur_ppl=$(grep -o 'Final estimate: PPL = [0-9.]*' ./${mid}-${t}-gguf-ppl.txt | sed -e "s/.*Final//" | cut -d' ' -f5)
|
||||
cur_ppl_delta=$(grep -o 'Final estimate: PPL = [0-9.]* +/- [0-9.]*' ./${mid}-${t}-gguf-ppl.txt | sed -e "s/.*Final//" | cut -d' ' -f7)
|
||||
cur_size=$(cat ./${mid}-${t}-gguf-size.txt)
|
||||
cur_pp2k=$(grep -o '| 2048 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $12}')
|
||||
cur_tg2k=$(grep -o '| 2048 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $16}')
|
||||
cur_pp4k=$(grep -o '| 4096 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $12}')
|
||||
cur_tg4k=$(grep -o '| 4096 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $16}')
|
||||
cur_pp8k=$(grep -o '| 8192 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $12}')
|
||||
cur_tg8k=$(grep -o '| 8192 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $16}')
|
||||
cur_pp16k=$(grep -o '| 16384 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $12}')
|
||||
cur_tg16k=$(grep -o '| 16384 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $16}')
|
||||
cur_pp32k=$(grep -o '| 32768 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $12}')
|
||||
cur_tg32k=$(grep -o '| 32768 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $16}')
|
||||
fi
|
||||
|
||||
gguf_ppls+=("${cur_ppl}")
|
||||
gguf_ppl_deltas+=("${cur_ppl_delta}")
|
||||
gguf_sizes+=("${cur_size}")
|
||||
gguf_pps2k+=("${cur_pp2k}")
|
||||
gguf_tgs2k+=("${cur_tg2k}")
|
||||
gguf_pps4k+=("${cur_pp4k}")
|
||||
gguf_tgs4k+=("${cur_tg4k}")
|
||||
gguf_pps8k+=("${cur_pp8k}")
|
||||
gguf_tgs8k+=("${cur_tg8k}")
|
||||
gguf_pps16k+=("${cur_pp16k}")
|
||||
gguf_tgs16k+=("${cur_tg16k}")
|
||||
gguf_pps32k+=("${cur_pp32k}")
|
||||
gguf_tgs32k+=("${cur_tg32k}")
|
||||
done
|
||||
|
||||
res="${mid}-results.txt"
|
||||
echo "Results for ${org}/${mid} saved to ${res}"
|
||||
|
||||
printf "\n" | tee ${res}
|
||||
printf "Model ID: ${org}/${mid}\n" | tee -a ${res}
|
||||
#printf "Samples: ${num_samples}\n" | tee -a ${res}
|
||||
#printf "Sequence Length: ${sequence_length}\n" | tee -a ${res}
|
||||
printf "\n" | tee -a ${res}
|
||||
printf "| Type | MLX PPL | GGUF PPL | MLX Size | GGUF Size | MLX PP 2K | GGUF PP 2K | MLX TG 2K | GGUF TG 2K | MLX PP 4K | GGUF PP 4K | MLX TG 4K | GGUF TG 4K | MLX PP 8K | GGUF PP 8K | MLX TG 8K | GGUF TG 8K | MLX PP 16K | GGUF PP 16K | MLX TG 16K | GGUF TG 16K | MLX PP 32K | GGUF PP 32K | MLX TG 32K | GGUF TG 32K |\n" | tee -a ${res}
|
||||
printf "|-------|---------------------|------------------------|----------|-----------| ---------- | ----------- | ---------- | ----------- | ---------- | ----------- | ---------- | ----------- | ---------- | ----------- | ---------- | ----------- | ---------- | ----------- | ---------- | ----------- | ---------- | ----------- | ---------- | ----------- |\n" | tee -a ${res}
|
||||
|
||||
for i in "${!types[@]}"; do
|
||||
printf "| %-5s | %10s ± %6s | %10s ± %9s | %8s | %9s | %10s | %11s | %10s | %11s | %10s | %11s | %10s | %11s | %10s | %11s | %10s | %11s | %10s | %11s | %10s | %11s | %10s | %11s | %10s | %11s |\n" \
|
||||
"${types[i]}" \
|
||||
"${mlx_ppls[i]}" \
|
||||
"${mlx_ppl_deltas[i]}" \
|
||||
"${gguf_ppls[i]}" \
|
||||
"${gguf_ppl_deltas[i]}" \
|
||||
"${mlx_sizes[i]}" \
|
||||
"${gguf_sizes[i]}" \
|
||||
"${mlx_pps2k[i]}" \
|
||||
"${gguf_pps2k[i]}" \
|
||||
"${mlx_tgs2k[i]}" \
|
||||
"${gguf_tgs2k[i]}" \
|
||||
"${mlx_pps4k[i]}" \
|
||||
"${gguf_pps4k[i]}" \
|
||||
"${mlx_tgs4k[i]}" \
|
||||
"${gguf_tgs4k[i]}" \
|
||||
"${mlx_pps8k[i]}" \
|
||||
"${gguf_pps8k[i]}" \
|
||||
"${mlx_tgs8k[i]}" \
|
||||
"${gguf_tgs8k[i]}" \
|
||||
"${mlx_pps16k[i]}" \
|
||||
"${gguf_pps16k[i]}" \
|
||||
"${mlx_tgs16k[i]}" \
|
||||
"${gguf_tgs16k[i]}" \
|
||||
"${mlx_pps32k[i]}" \
|
||||
"${gguf_pps32k[i]}" \
|
||||
"${mlx_tgs32k[i]}" \
|
||||
"${gguf_tgs32k[i]}" | tee -a ${res}
|
||||
done
|
||||
}
|
||||
|
||||
eval_model "meta-llama" "Llama-3.2-1B"
|
||||
eval_model "meta-llama" "Llama-3.2-3B"
|
||||
eval_model "meta-llama" "Llama-3.1-8B"
|
||||
|
||||
eval_model "google" "gemma-3-270m"
|
||||
eval_model "google" "gemma-3-1b-pt"
|
||||
#eval_model "google" "gemma-3-4b-pt"
|
||||
|
||||
# the mlx-ppl.y script does not work with these models - not sure why
|
||||
#eval_model "google" "gemma-3n-E2B"
|
||||
#eval_model "google" "gemma-3n-E4B"
|
||||
|
||||
eval_model "Qwen" "Qwen3-0.6B-Base"
|
||||
eval_model "Qwen" "Qwen3-1.7B-Base"
|
||||
eval_model "Qwen" "Qwen3-4B-Base"
|
||||
eval_model "Qwen" "Qwen3-8B-Base"
|
||||
eval_model "Qwen" "Qwen3-30B-A3B-Base"
|
||||
120
examples/compare-mlx/inspect_model.py
Normal file
120
examples/compare-mlx/inspect_model.py
Normal file
@@ -0,0 +1,120 @@
|
||||
#!/usr/bin/env python3
|
||||
# generated by Claude
|
||||
"""
|
||||
Script to inspect SafeTensors model files and print tensor information.
|
||||
"""
|
||||
|
||||
import json
|
||||
from safetensors import safe_open
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
def inspect_safetensors_model(model_dir="."):
|
||||
"""Inspect all SafeTensors files in the model directory."""
|
||||
|
||||
# First, let's read the index file to see the file structure
|
||||
index_file = Path(model_dir) / "model.safetensors.index.json"
|
||||
|
||||
if index_file.exists():
|
||||
with open(index_file, 'r') as f:
|
||||
index_data = json.load(f)
|
||||
|
||||
print("=== Model Structure ===")
|
||||
print(f"Total parameters: {index_data.get('metadata', {}).get('total_size', 'Unknown')}")
|
||||
print()
|
||||
|
||||
# Get all safetensor files
|
||||
safetensor_files = set(index_data.get('weight_map', {}).values())
|
||||
else:
|
||||
# If no index file, look for safetensor files directly
|
||||
safetensor_files = [f for f in os.listdir(model_dir) if f.endswith('.safetensors')]
|
||||
|
||||
# Sort files for consistent output
|
||||
safetensor_files = sorted(safetensor_files)
|
||||
|
||||
print("=== Tensor Information ===")
|
||||
print(f"{'Tensor Name':<50} {'Shape':<25} {'Data Type':<15} {'File'}")
|
||||
print("-" * 110)
|
||||
|
||||
total_tensors = 0
|
||||
|
||||
for filename in safetensor_files:
|
||||
filepath = Path(model_dir) / filename
|
||||
if not filepath.exists():
|
||||
continue
|
||||
|
||||
print(f"\n--- {filename} ---")
|
||||
|
||||
# Open and inspect the safetensor file
|
||||
with safe_open(filepath, framework="pt") as f: # Use PyTorch framework for better dtype support
|
||||
tensor_names = f.keys()
|
||||
|
||||
for tensor_name in sorted(tensor_names):
|
||||
# Get tensor metadata without loading the full tensor
|
||||
tensor_slice = f.get_slice(tensor_name)
|
||||
shape = tensor_slice.get_shape()
|
||||
dtype = tensor_slice.get_dtype()
|
||||
|
||||
shape_str = str(tuple(shape))
|
||||
dtype_str = str(dtype)
|
||||
|
||||
print(f"{tensor_name:<50} {shape_str:<25} {dtype_str:<15} {filename}")
|
||||
total_tensors += 1
|
||||
|
||||
print(f"\nTotal tensors found: {total_tensors}")
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Inspect SafeTensors model files")
|
||||
parser.add_argument("--model-dir", "-d", default=".",
|
||||
help="Directory containing the model files (default: current directory)")
|
||||
parser.add_argument("--summary", "-s", action="store_true",
|
||||
help="Show only summary statistics")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.summary:
|
||||
print_summary_only(args.model_dir)
|
||||
else:
|
||||
inspect_safetensors_model(args.model_dir)
|
||||
|
||||
def print_summary_only(model_dir="."):
|
||||
"""Print only summary statistics."""
|
||||
safetensor_files = [f for f in os.listdir(model_dir) if f.endswith('.safetensors')]
|
||||
|
||||
total_tensors = 0
|
||||
dtype_counts = {}
|
||||
total_params = 0
|
||||
|
||||
for filename in sorted(safetensor_files):
|
||||
filepath = Path(model_dir) / filename
|
||||
if not filepath.exists():
|
||||
continue
|
||||
|
||||
with safe_open(filepath, framework="pt") as f: # Use PyTorch framework
|
||||
for tensor_name in f.keys():
|
||||
tensor_slice = f.get_slice(tensor_name)
|
||||
shape = tensor_slice.get_shape()
|
||||
dtype = tensor_slice.get_dtype()
|
||||
|
||||
total_tensors += 1
|
||||
|
||||
dtype_str = str(dtype)
|
||||
dtype_counts[dtype_str] = dtype_counts.get(dtype_str, 0) + 1
|
||||
|
||||
# Calculate parameter count
|
||||
param_count = 1
|
||||
for dim in shape:
|
||||
param_count *= dim
|
||||
total_params += param_count
|
||||
|
||||
print("=== Model Summary ===")
|
||||
print(f"Total tensors: {total_tensors}")
|
||||
print(f"Total parameters: {total_params:,}")
|
||||
print(f"Data type distribution:")
|
||||
for dtype, count in sorted(dtype_counts.items()):
|
||||
print(f" {dtype}: {count} tensors")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
305
examples/compare-mlx/mlx-ppl.py
Normal file
305
examples/compare-mlx/mlx-ppl.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
# modified: https://github.com/ml-explore/mlx-lm/blob/60320dc2347d45dc3ca08be90e5255fb9424bb09/mlx_lm/perplexity.py
|
||||
"""
|
||||
Evaluate perplexity (PPL) of pre-trained MLX models in the same way as llama.cpp's llama-perplexity.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from mlx_lm.tuner.datasets import load_dataset
|
||||
from mlx_lm.tuner.utils import get_total_parameters
|
||||
from mlx_lm.utils import load
|
||||
|
||||
|
||||
def load_data(
|
||||
tokenizer,
|
||||
data_path: str,
|
||||
num_samples: int,
|
||||
sequence_length: int,
|
||||
):
|
||||
"""
|
||||
Load a Hugging‑Face dataset (via mlx‑lm’s dataset utilities) and convert it
|
||||
into a token tensor of shape (N, sequence_length).
|
||||
"""
|
||||
args = types.SimpleNamespace(
|
||||
hf_dataset={
|
||||
"path": data_path,
|
||||
"train_split": "train",
|
||||
"valid_split": "train[:1]",
|
||||
},
|
||||
train=True,
|
||||
test=False,
|
||||
)
|
||||
dataset = load_dataset(args, tokenizer)[0]
|
||||
|
||||
perm = np.random.permutation(len(dataset)).tolist()
|
||||
|
||||
num_tokens = sequence_length * num_samples if num_samples > 0 else float("inf")
|
||||
data = []
|
||||
i = 0
|
||||
while len(data) < num_tokens:
|
||||
tokens, _ = dataset.process(dataset[perm[i]])
|
||||
i += 1
|
||||
data.extend(tokens)
|
||||
|
||||
# Convert to MX array, truncate to a multiple of `sequence_length`
|
||||
data = mx.array(data[: (len(data) // sequence_length) * sequence_length])
|
||||
data = data.reshape(-1, sequence_length)
|
||||
if num_samples > 0:
|
||||
data = data[:num_samples]
|
||||
return data
|
||||
|
||||
|
||||
def _tokenize_text(tokenizer, text: str):
|
||||
"""
|
||||
Helper that tokenises a string using the MLX‑LM tokenizer.
|
||||
Supports the common `encode` method or a callable tokenizer.
|
||||
"""
|
||||
# Most mlx‑lm tokenizers expose an `encode` method.
|
||||
if hasattr(tokenizer, "encode"):
|
||||
tokens = tokenizer.encode(text)
|
||||
elif callable(tokenizer):
|
||||
tokens = tokenizer(text)
|
||||
else:
|
||||
raise AttributeError(
|
||||
"Tokenizer does not have an `encode` method nor is it callable."
|
||||
)
|
||||
# Normalise the output to a Python list of ints.
|
||||
if isinstance(tokens, mx.array):
|
||||
tokens = tokens.tolist()
|
||||
return tokens
|
||||
|
||||
|
||||
# load a raw text file and tokenize it
|
||||
# generated with gpt-oss-120b
|
||||
def load_raw_data(
|
||||
tokenizer,
|
||||
raw_path: str,
|
||||
num_samples: int,
|
||||
sequence_length: int,
|
||||
):
|
||||
"""
|
||||
Load a raw text file, tokenize it, and reshape into a (N, sequence_length)
|
||||
tensor suitable for perplexity evaluation.
|
||||
"""
|
||||
if not os.path.isfile(raw_path):
|
||||
raise FileNotFoundError(f"Raw text file not found: {raw_path}")
|
||||
|
||||
# Read the whole file (UTF‑8). Users can supply any plain‑text corpus.
|
||||
with open(raw_path, "r", encoding="utf-8") as fp:
|
||||
raw_text = fp.read()
|
||||
|
||||
# Tokenise the complete text.
|
||||
token_list = _tokenize_text(tokenizer, raw_text)
|
||||
|
||||
if len(token_list) == 0:
|
||||
raise ValueError("Tokenisation of the raw file produced no tokens.")
|
||||
|
||||
# Convert to MX array (int32 is sufficient for token IDs).
|
||||
token_array = mx.array(token_list, dtype=mx.int32)
|
||||
|
||||
# Trim to a length that is an exact multiple of `sequence_length`.
|
||||
total_len = (token_array.shape[0] // sequence_length) * sequence_length
|
||||
token_array = token_array[:total_len]
|
||||
|
||||
# Reshape into (num_sequences, sequence_length)
|
||||
data = token_array.reshape(-1, sequence_length)
|
||||
|
||||
if num_samples > 0:
|
||||
data = data[:num_samples]
|
||||
|
||||
#print(f"First 4 samples of the data:")
|
||||
#for j in range(min(4, len(data))):
|
||||
# print(f" Sample {j}: {tokenizer.decode(data[j].tolist())}\n\n-------------------\n\n")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def eval_ppl(model, tokenizer, data, batch_size=8):
|
||||
"""
|
||||
Evaluate perplexity on a dataset with standard error calculation.
|
||||
|
||||
Args:
|
||||
model: The model to evaluate.
|
||||
data: Tokenized data tensor (shape: N x L).
|
||||
batch_size: Batch size for evaluation.
|
||||
|
||||
Returns:
|
||||
tuple: (perplexity, standard_error_of_perplexity)
|
||||
"""
|
||||
all_losses = []
|
||||
|
||||
num_batches = (len(data) + batch_size - 1) // batch_size
|
||||
for i, s in enumerate(range(0, len(data), batch_size)):
|
||||
batch = data[s : s + batch_size]
|
||||
|
||||
# Set the first token of all samples to the BOS token
|
||||
if tokenizer.bos_token_id:
|
||||
batch[:, 0] = tokenizer.bos_token_id
|
||||
|
||||
# compute cross entropy only with the second half of the sequence to match llama.cpp behavior
|
||||
# ref: https://github.com/ggml-org/llama.cpp/blob/696fccf354e9dbdfbce135bc40b44c9dcc64dda9/tools/perplexity/perplexity.cpp#L527-L541
|
||||
#
|
||||
#start = 0
|
||||
start = batch.shape[1] // 2
|
||||
|
||||
# Forward pass: get logits for all tokens except last
|
||||
logits = model(batch[:, :-1]).astype(mx.float32)
|
||||
|
||||
# Calculate cross‑entropy loss with next tokens
|
||||
#losses = nn.losses.cross_entropy(logits, batch[:, 1:], reduction="none")
|
||||
losses = nn.losses.cross_entropy(logits[:, start:, :], batch[:, start+1:], reduction="none")
|
||||
|
||||
mx.eval(losses)
|
||||
# Store individual token losses
|
||||
all_losses.append(losses.flatten())
|
||||
|
||||
# Progress indicator
|
||||
if (i + 1) % 1 == 0 or (i + 1) == num_batches:
|
||||
print(f" Processed {i + 1}/{num_batches} batches...", end="\r")
|
||||
|
||||
print() # New line after progress
|
||||
|
||||
# Concatenate all losses into a single array
|
||||
all_losses = mx.concatenate(all_losses)
|
||||
|
||||
# Calculate mean loss and perplexity
|
||||
mean_loss = all_losses.mean().item()
|
||||
ppl = math.exp(mean_loss)
|
||||
|
||||
# Calculate standard error
|
||||
std_dev = mx.sqrt(mx.var(all_losses, ddof=1)).item()
|
||||
num_tokens = all_losses.size
|
||||
standard_error = std_dev / math.sqrt(num_tokens)
|
||||
|
||||
# Delta approximation for standard error of perplexity
|
||||
standard_error_ppl = ppl * standard_error
|
||||
|
||||
return ppl, standard_error_ppl
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Evaluate perplexity of MLX models")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to model or Hugging Face model ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size", type=int, default=8, help="Batch size for evaluation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sequence-length",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Sequence length for evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-samples",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Number of samples to use (-1 for all available)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-path",
|
||||
type=str,
|
||||
default="allenai/tulu-3-sft-mixture",
|
||||
help=(
|
||||
"A Hugging Face dataset compatible with mlx‑lm. "
|
||||
"Ignored if --raw-path is provided."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--raw-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Path to a local raw‑text file to use for evaluation. "
|
||||
"If specified, the script skips loading a HF dataset."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", type=int, default=123, help="Random seed for data sampling"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set random seed (used for HF dataset shuffling)
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
# Load model
|
||||
print(f"Loading model from {args.model}...")
|
||||
model, tokenizer = load(args.model)
|
||||
|
||||
# Count parameters
|
||||
total_params = get_total_parameters(model)
|
||||
print(f"Model loaded: {total_params/1e6:.1f}M parameters")
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Load evaluation data (raw file vs. HF dataset)
|
||||
# ----------------------------------------------------------------------
|
||||
print("\nLoading dataset...")
|
||||
print(f" Sequence length: {args.sequence_length}")
|
||||
|
||||
if args.raw_path:
|
||||
print(f" Using raw text file: {args.raw_path}")
|
||||
data = load_raw_data(
|
||||
tokenizer,
|
||||
raw_path=args.raw_path,
|
||||
num_samples=args.num_samples,
|
||||
sequence_length=args.sequence_length,
|
||||
)
|
||||
else:
|
||||
print(f" Using HF dataset: {args.data_path}")
|
||||
data = load_data(
|
||||
tokenizer,
|
||||
data_path=args.data_path,
|
||||
num_samples=args.num_samples,
|
||||
sequence_length=args.sequence_length,
|
||||
)
|
||||
|
||||
print(f" Loaded {len(data)} samples")
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Evaluate perplexity
|
||||
# ----------------------------------------------------------------------
|
||||
print(f"\nEvaluating perplexity with batch size {args.batch_size}...")
|
||||
start_time = time.time()
|
||||
|
||||
ppl, se = eval_ppl(model, tokenizer, data, batch_size=args.batch_size)
|
||||
|
||||
eval_time = time.time() - start_time
|
||||
tokens_evaluated = data.shape[0] * (data.shape[1] - 1) # B * (L - 1)
|
||||
|
||||
# Print results
|
||||
print("\n" + "=" * 60)
|
||||
print("EVALUATION RESULTS")
|
||||
print("=" * 60)
|
||||
print(f"Model: {args.model}")
|
||||
print(f"Perplexity: {ppl:.3f} ± {se:.3f}")
|
||||
print(f"Evaluation time: {eval_time:.2f} seconds")
|
||||
print(f"Peak memory: {mx.get_peak_memory() / 1e9:.2f} GB")
|
||||
print(f"Tokens per second: {tokens_evaluated / eval_time:.0f}")
|
||||
|
||||
# Additional statistics
|
||||
print(f"\nDataset statistics:")
|
||||
print(f" Total samples: {len(data)}")
|
||||
print(f" Total tokens: {data.size}")
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Done
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -50,6 +50,12 @@ int main(int argc, char ** argv) {
|
||||
const int N = 5; // n-gram size
|
||||
const int G = 15; // max verification n-grams
|
||||
|
||||
// lookahead requires W + G + 1 sequences for parallel Jacobi decoding
|
||||
params.n_parallel = W + G + 1;
|
||||
|
||||
// unified KV cache is required for coupled sequences in batch splitting
|
||||
params.kv_unified = true;
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
@@ -115,7 +121,7 @@ int main(int argc, char ** argv) {
|
||||
// seq_id == 0 : the current input token
|
||||
// seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
|
||||
// seq_id [W + 1, W + G] : verification n-grams
|
||||
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
|
||||
llama_batch batch = llama_batch_init(llama_n_ctx(ctx), 0, W + G + 1);
|
||||
|
||||
// target model sampling context
|
||||
struct common_sampler * smpl = common_sampler_init(model, params.sampling);
|
||||
|
||||
@@ -32,9 +32,9 @@ int main(int argc, char ** argv){
|
||||
|
||||
common_ngram_cache ngram_cache;
|
||||
common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
|
||||
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());
|
||||
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.speculative.lookup_cache_static.c_str());
|
||||
|
||||
common_ngram_cache_save(ngram_cache, params.lookup_cache_static);
|
||||
common_ngram_cache_save(ngram_cache, params.speculative.lookup_cache_static);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -46,18 +46,18 @@ int main(int argc, char ** argv){
|
||||
{
|
||||
const int64_t t_start_draft_us = ggml_time_us();
|
||||
|
||||
if (!params.lookup_cache_static.empty()) {
|
||||
if (!params.speculative.lookup_cache_static.empty()) {
|
||||
try {
|
||||
ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
|
||||
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
|
||||
} catch (std::ifstream::failure const &) {
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.lookup_cache_dynamic.empty()) {
|
||||
if (!params.speculative.lookup_cache_dynamic.empty()) {
|
||||
try {
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
|
||||
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
|
||||
}
|
||||
|
||||
|
||||
@@ -51,18 +51,18 @@ int main(int argc, char ** argv){
|
||||
const int64_t t_start_draft_us = ggml_time_us();
|
||||
common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);
|
||||
|
||||
if (!params.lookup_cache_static.empty()) {
|
||||
if (!params.speculative.lookup_cache_static.empty()) {
|
||||
try {
|
||||
ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
|
||||
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
|
||||
} catch (std::ifstream::failure const &) {
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.lookup_cache_dynamic.empty()) {
|
||||
if (!params.speculative.lookup_cache_dynamic.empty()) {
|
||||
try {
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
|
||||
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ int main(int argc, char ** argv){
|
||||
|
||||
std::vector<llama_token> draft;
|
||||
|
||||
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
|
||||
llama_batch batch_tgt = llama_batch_init(llama_n_ctx(ctx), 0, 1);
|
||||
|
||||
const auto t_dec_start = ggml_time_us();
|
||||
|
||||
@@ -210,7 +210,7 @@ int main(int argc, char ** argv){
|
||||
|
||||
// Update dynamic ngram cache with context ngram cache and save it to disk:
|
||||
common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
|
||||
common_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic);
|
||||
common_ngram_cache_save(ngram_cache_dynamic, params.speculative.lookup_cache_dynamic);
|
||||
|
||||
LOG("\n\n");
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_init();
|
||||
|
||||
if (params.speculative.model.path.empty()) {
|
||||
if (params.speculative.mparams_dft.path.empty()) {
|
||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -34,10 +34,8 @@ int main(int argc, char ** argv) {
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model_tgt = NULL;
|
||||
//llama_model * model_dft = NULL;
|
||||
|
||||
llama_context * ctx_tgt = NULL;
|
||||
llama_context * ctx_dft = NULL;
|
||||
|
||||
// load the target model
|
||||
auto llama_init_tgt = common_init_from_params(params);
|
||||
@@ -48,26 +46,38 @@ int main(int argc, char ** argv) {
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
|
||||
|
||||
// load the draft model
|
||||
params.devices = params.speculative.devices;
|
||||
params.model = params.speculative.model;
|
||||
params.n_ctx = params.speculative.n_ctx;
|
||||
params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
|
||||
params.n_gpu_layers = params.speculative.n_gpu_layers;
|
||||
llama_model_ptr model_dft;
|
||||
|
||||
if (params.speculative.cpuparams.n_threads > 0) {
|
||||
params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
|
||||
}
|
||||
// TODO: simplify this logic
|
||||
{
|
||||
const auto & params_spec = params.speculative;
|
||||
|
||||
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
|
||||
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
|
||||
auto params_dft = params;
|
||||
|
||||
auto llama_init_dft = common_init_from_params(params);
|
||||
params_dft.n_parallel = 1;
|
||||
params_dft.n_ctx = params_spec.n_ctx;
|
||||
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
|
||||
params_dft.devices = params_spec.devices;
|
||||
params_dft.model = params_spec.mparams_dft;
|
||||
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
|
||||
|
||||
//model_dft = llama_init_dft->model();
|
||||
ctx_dft = llama_init_dft->context();
|
||||
if (params_spec.cpuparams.n_threads > 0) {
|
||||
params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
|
||||
params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
|
||||
}
|
||||
|
||||
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
|
||||
LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
|
||||
params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
|
||||
|
||||
auto mparams_dft = common_model_params_to_llama(params_dft);
|
||||
|
||||
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
|
||||
if (model_dft == nullptr) {
|
||||
LOG_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
params.speculative.model_dft = model_dft.get();
|
||||
params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
|
||||
}
|
||||
|
||||
// Tokenize the prompt
|
||||
@@ -92,12 +102,6 @@ int main(int argc, char ** argv) {
|
||||
LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
|
||||
}
|
||||
|
||||
// how many tokens to draft each time
|
||||
int n_draft = params.speculative.n_max;
|
||||
int n_draft_min = params.speculative.n_min;
|
||||
|
||||
float p_min = params.speculative.p_min;
|
||||
|
||||
int n_predict = 0;
|
||||
int n_drafted = 0;
|
||||
int n_accept = 0;
|
||||
@@ -127,15 +131,11 @@ int main(int argc, char ** argv) {
|
||||
int n_past = inp.size() - 1;
|
||||
|
||||
// init the speculator
|
||||
struct common_speculative_params params_spec;
|
||||
params_spec.n_draft = n_draft;
|
||||
params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
|
||||
params_spec.p_min = p_min;
|
||||
const auto & params_spec = params.speculative;
|
||||
|
||||
struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft);
|
||||
for (auto &pair : params.speculative.replacements) {
|
||||
common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str());
|
||||
}
|
||||
struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
|
||||
|
||||
common_speculative_begin(spec, prompt_tgt);
|
||||
|
||||
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
|
||||
|
||||
@@ -151,7 +151,7 @@ int main(int argc, char ** argv) {
|
||||
// offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
|
||||
// from a cache or lookup tables.
|
||||
//
|
||||
llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
|
||||
llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
|
||||
|
||||
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
|
||||
|
||||
@@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
|
||||
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
|
||||
{
|
||||
// do not waste time on small drafts
|
||||
if (draft.size() < (size_t) n_draft_min) {
|
||||
if (draft.size() < (size_t) params_spec.n_min) {
|
||||
draft.clear();
|
||||
}
|
||||
|
||||
@@ -240,7 +240,7 @@ int main(int argc, char ** argv) {
|
||||
LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
|
||||
|
||||
LOG_INF("\n");
|
||||
LOG_INF("n_draft = %d\n", n_draft);
|
||||
LOG_INF("n_draft = %d\n", params_spec.n_max);
|
||||
LOG_INF("n_predict = %d\n", n_predict);
|
||||
LOG_INF("n_drafted = %d\n", n_drafted);
|
||||
LOG_INF("n_accept = %d\n", n_accept);
|
||||
@@ -249,8 +249,6 @@ int main(int argc, char ** argv) {
|
||||
LOG_INF("\n");
|
||||
LOG_INF("draft:\n\n");
|
||||
|
||||
llama_perf_context_print(ctx_dft);
|
||||
|
||||
LOG_INF("\n");
|
||||
LOG_INF("target:\n\n");
|
||||
common_perf_print(ctx_tgt, smpl);
|
||||
|
||||
@@ -46,7 +46,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_init();
|
||||
|
||||
if (params.speculative.model.path.empty()) {
|
||||
if (params.speculative.mparams_dft.path.empty()) {
|
||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -78,7 +78,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// load the draft model
|
||||
params.devices = params.speculative.devices;
|
||||
params.model = params.speculative.model;
|
||||
params.model = params.speculative.mparams_dft;
|
||||
params.n_gpu_layers = params.speculative.n_gpu_layers;
|
||||
if (params.speculative.cpuparams.n_threads > 0) {
|
||||
params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
|
||||
|
||||
@@ -228,6 +228,8 @@ option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)
|
||||
option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF)
|
||||
option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON)
|
||||
option(GGML_ZDNN "ggml: use zDNN" OFF)
|
||||
option(GGML_VIRTGPU "ggml: use the VirtGPU/Virglrenderer API Remoting frontend" OFF)
|
||||
option(GGML_VIRTGPU_BACKEND "ggml: build the VirtGPU/Virglrenderer API Remoting backend" OFF)
|
||||
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
|
||||
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
|
||||
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
|
||||
@@ -320,6 +322,7 @@ set(GGML_PUBLIC_HEADERS
|
||||
include/ggml-opt.h
|
||||
include/ggml-metal.h
|
||||
include/ggml-rpc.h
|
||||
include/ggml-virtgpu.h
|
||||
include/ggml-sycl.h
|
||||
include/ggml-vulkan.h
|
||||
include/ggml-webgpu.h
|
||||
|
||||
16
ggml/include/ggml-virtgpu.h
Normal file
16
ggml/include/ggml-virtgpu.h
Normal file
@@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define GGML_REMOTING_FRONTEND_NAME "RemotingFrontend"
|
||||
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_virtgpu_reg();
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -222,6 +222,7 @@ if (GGML_SCHED_NO_REALLOC)
|
||||
endif()
|
||||
|
||||
add_library(ggml
|
||||
ggml-backend-dl.cpp
|
||||
ggml-backend-reg.cpp)
|
||||
add_library(ggml::ggml ALIAS ggml)
|
||||
|
||||
@@ -451,6 +452,7 @@ ggml_add_backend(HIP)
|
||||
ggml_add_backend(METAL)
|
||||
ggml_add_backend(MUSA)
|
||||
ggml_add_backend(RPC)
|
||||
ggml_add_backend(VirtGPU)
|
||||
ggml_add_backend(SYCL)
|
||||
ggml_add_backend(Vulkan)
|
||||
ggml_add_backend(WebGPU)
|
||||
|
||||
48
ggml/src/ggml-backend-dl.cpp
Normal file
48
ggml/src/ggml-backend-dl.cpp
Normal file
@@ -0,0 +1,48 @@
|
||||
#include "ggml-backend-dl.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
dl_handle * dl_load_library(const fs::path & path) {
|
||||
// suppress error dialogs for missing DLLs
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
HMODULE handle = LoadLibraryW(path.wstring().c_str());
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
void * p = (void *) GetProcAddress(handle, name);
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return p;
|
||||
}
|
||||
|
||||
const char * dl_error() {
|
||||
return "";
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
dl_handle * dl_load_library(const fs::path & path) {
|
||||
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
return handle;
|
||||
}
|
||||
|
||||
void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
return dlsym(handle, name);
|
||||
}
|
||||
|
||||
const char * dl_error() {
|
||||
const char *rslt = dlerror();
|
||||
return rslt != nullptr ? rslt : "";
|
||||
}
|
||||
|
||||
#endif
|
||||
45
ggml/src/ggml-backend-dl.h
Normal file
45
ggml/src/ggml-backend-dl.h
Normal file
@@ -0,0 +1,45 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef _WIN32
|
||||
# define WIN32_LEAN_AND_MEAN
|
||||
# ifndef NOMINMAX
|
||||
# define NOMINMAX
|
||||
# endif
|
||||
# include <windows.h>
|
||||
# include <winevt.h>
|
||||
#else
|
||||
# include <dlfcn.h>
|
||||
# include <unistd.h>
|
||||
#endif
|
||||
#include <filesystem>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
using dl_handle = std::remove_pointer_t<HMODULE>;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(HMODULE handle) {
|
||||
FreeLibrary(handle);
|
||||
}
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
using dl_handle = void;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(void * handle) {
|
||||
dlclose(handle);
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
|
||||
|
||||
dl_handle * dl_load_library(const fs::path & path);
|
||||
void * dl_get_sym(dl_handle * handle, const char * name);
|
||||
const char * dl_error();
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-backend-dl.h"
|
||||
#include "ggml-impl.h"
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
@@ -69,6 +70,10 @@
|
||||
#include "ggml-rpc.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_VIRTGPU_FRONTEND
|
||||
#include "ggml-virtgpu.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_CANN
|
||||
#include "ggml-cann.h"
|
||||
#endif
|
||||
@@ -94,72 +99,6 @@ static std::string path_str(const fs::path & path) {
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
using dl_handle = std::remove_pointer_t<HMODULE>;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(HMODULE handle) {
|
||||
FreeLibrary(handle);
|
||||
}
|
||||
};
|
||||
|
||||
static dl_handle * dl_load_library(const fs::path & path) {
|
||||
// suppress error dialogs for missing DLLs
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
HMODULE handle = LoadLibraryW(path.wstring().c_str());
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
static void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
void * p = (void *) GetProcAddress(handle, name);
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return p;
|
||||
}
|
||||
|
||||
static const char * dl_error() {
|
||||
return "";
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
using dl_handle = void;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(void * handle) {
|
||||
dlclose(handle);
|
||||
}
|
||||
};
|
||||
|
||||
static void * dl_load_library(const fs::path & path) {
|
||||
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
static void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
return dlsym(handle, name);
|
||||
}
|
||||
|
||||
static const char * dl_error() {
|
||||
const char *rslt = dlerror();
|
||||
return rslt != nullptr ? rslt : "";
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
|
||||
|
||||
struct ggml_backend_reg_entry {
|
||||
ggml_backend_reg_t reg;
|
||||
dl_handle_ptr handle;
|
||||
@@ -180,7 +119,12 @@ struct ggml_backend_registry {
|
||||
register_backend(ggml_backend_sycl_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_VULKAN
|
||||
// Add runtime disable check
|
||||
if (getenv("GGML_DISABLE_VULKAN") == nullptr) {
|
||||
register_backend(ggml_backend_vk_reg());
|
||||
} else {
|
||||
GGML_LOG_DEBUG("Vulkan backend disabled by GGML_DISABLE_VULKAN environment variable\n");
|
||||
}
|
||||
#endif
|
||||
#ifdef GGML_USE_WEBGPU
|
||||
register_backend(ggml_backend_webgpu_reg());
|
||||
@@ -188,6 +132,10 @@ struct ggml_backend_registry {
|
||||
#ifdef GGML_USE_ZDNN
|
||||
register_backend(ggml_backend_zdnn_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_VIRTGPU_FRONTEND
|
||||
register_backend(ggml_backend_virtgpu_reg());
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_OPENCL
|
||||
register_backend(ggml_backend_opencl_reg());
|
||||
#endif
|
||||
@@ -604,6 +552,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
||||
ggml_backend_load_best("rpc", silent, dir_path);
|
||||
ggml_backend_load_best("sycl", silent, dir_path);
|
||||
ggml_backend_load_best("vulkan", silent, dir_path);
|
||||
ggml_backend_load_best("virtgpu", silent, dir_path);
|
||||
ggml_backend_load_best("opencl", silent, dir_path);
|
||||
ggml_backend_load_best("hexagon", silent, dir_path);
|
||||
ggml_backend_load_best("musa", silent, dir_path);
|
||||
|
||||
@@ -3148,16 +3148,17 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
||||
|
||||
// Scales[i] corresponds to column i
|
||||
const int scale_offset = cp * 2;
|
||||
for (int blk = 0; blk < 2; blk++) {
|
||||
const int32x4_t block_scale = {
|
||||
(int32_t) q4sb_scales[blk][scale_offset],
|
||||
(int32_t) q4sb_scales[blk][scale_offset],
|
||||
(int32_t) q4sb_scales[blk][scale_offset + 1],
|
||||
(int32_t) q4sb_scales[blk][scale_offset + 1],
|
||||
};
|
||||
acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
|
||||
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
|
||||
}
|
||||
const int32_t scale_00 = q4sb_scales[0][scale_offset];
|
||||
const int32_t scale_01 = q4sb_scales[0][scale_offset + 1];
|
||||
const int32_t scale_10 = q4sb_scales[1][scale_offset];
|
||||
const int32_t scale_11 = q4sb_scales[1][scale_offset + 1];
|
||||
const int32x4_t block_scale_0 = vcombine_s32(vdup_n_s32(scale_00), vdup_n_s32(scale_01));
|
||||
const int32x4_t block_scale_1 = vcombine_s32(vdup_n_s32(scale_10), vdup_n_s32(scale_11));
|
||||
|
||||
acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale_0);
|
||||
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale_0);
|
||||
acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale_1);
|
||||
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale_1);
|
||||
}
|
||||
|
||||
// Multiply Acc bsum + mins
|
||||
|
||||
@@ -53,6 +53,7 @@
|
||||
// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
|
||||
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
|
||||
#define GGML_CUDA_CC_BLACKWELL 1200
|
||||
#define GGML_CUDA_CC_DGX_SPARK 1210
|
||||
#define GGML_CUDA_CC_RUBIN 1300
|
||||
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
|
||||
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
|
||||
@@ -1121,15 +1122,18 @@ struct ggml_tensor_extra_gpu {
|
||||
#endif
|
||||
|
||||
struct ggml_cuda_graph_node_properties {
|
||||
void * node_address;
|
||||
void * node_data;
|
||||
ggml_op node_op;
|
||||
enum ggml_type node_type;
|
||||
int32_t flags;
|
||||
int64_t ne[GGML_MAX_DIMS];
|
||||
size_t nb[GGML_MAX_DIMS];
|
||||
void * src_address[GGML_MAX_SRC];
|
||||
void * src_data[GGML_MAX_SRC];
|
||||
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
};
|
||||
|
||||
static_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, "ggml_cuda_graph_node_properties must be trivial");
|
||||
|
||||
struct ggml_cuda_graph {
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
~ggml_cuda_graph() {
|
||||
@@ -1149,6 +1153,12 @@ struct ggml_cuda_graph {
|
||||
int number_consecutive_updates = 0;
|
||||
std::vector<ggml_cuda_graph_node_properties> props;
|
||||
|
||||
// these are extra tensors (inputs) that participate in the ggml graph but are not nodes
|
||||
// they properties also have to match in order to be able to safely reuse a CUDA graph
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18583
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/19165
|
||||
std::vector<ggml_cuda_graph_node_properties> extra;
|
||||
|
||||
void record_update(bool use_graph, bool update_required) {
|
||||
if (use_graph && update_required) {
|
||||
number_consecutive_updates++;
|
||||
|
||||
@@ -789,7 +789,7 @@ void launch_fattn(
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src);
|
||||
const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
||||
@@ -147,6 +147,14 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
if (gqa_ratio == 20) { // GLM 4.7 Flash
|
||||
if (cc >= GGML_CUDA_CC_DGX_SPARK) {
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
break;
|
||||
}
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
break;
|
||||
}
|
||||
if (cc >= GGML_CUDA_CC_BLACKWELL) {
|
||||
if (Q->ne[1] <= 4 && K->ne[1] >= 65536) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
@@ -302,8 +310,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
}
|
||||
}
|
||||
|
||||
const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src);
|
||||
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
|
||||
switch (K->ne[0]) {
|
||||
@@ -326,9 +332,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
if (!gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!V_is_K_view) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
|
||||
@@ -70,17 +70,18 @@
|
||||
#include <condition_variable>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <float.h>
|
||||
#include <cfloat>
|
||||
#include <initializer_list>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <stdarg.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <cstdarg>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
|
||||
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||
|
||||
@@ -2916,22 +2917,27 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
}
|
||||
|
||||
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
|
||||
props->node_address = node->data;
|
||||
memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
|
||||
props->node_data = node->data;
|
||||
props->node_op = node->op;
|
||||
props->node_type = node->type;
|
||||
props->flags = node->flags;
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
props->ne[i] = node->ne[i];
|
||||
props->nb[i] = node->nb[i];
|
||||
}
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
|
||||
if (!node->src[i]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
props->src_data[i] = node->src[i]->data;
|
||||
}
|
||||
memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
|
||||
}
|
||||
|
||||
static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
|
||||
if (node->data != props->node_address &&
|
||||
node->op != GGML_OP_VIEW) {
|
||||
if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2939,6 +2945,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node->type != props->node_type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
if (node->ne[i] != props->ne[i]) {
|
||||
return false;
|
||||
@@ -2948,12 +2958,18 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (node->src[i] &&
|
||||
node->src[i]->data != props->src_address[i] &&
|
||||
node->op != GGML_OP_VIEW
|
||||
) {
|
||||
return false;
|
||||
if (node->op != GGML_OP_VIEW) {
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (!node->src[i]) {
|
||||
if (props->src_data[i] != nullptr) {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->src[i]->data != props->src_data[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2974,7 +2990,6 @@ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
|
||||
}
|
||||
|
||||
static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
|
||||
|
||||
bool res = false;
|
||||
|
||||
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
@@ -2985,15 +3000,20 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
||||
}
|
||||
|
||||
// Check if the graph size has changed
|
||||
if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
|
||||
if (graph->props.size() != (size_t)cgraph->n_nodes) {
|
||||
res = true;
|
||||
graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
|
||||
graph->props.resize(cgraph->n_nodes);
|
||||
}
|
||||
|
||||
// Loop over nodes in GGML graph to determine if CUDA graph update is required
|
||||
// and store properties to allow this comparison for the next token
|
||||
std::unordered_set<ggml_tensor *> seen_node;
|
||||
std::vector<ggml_tensor *> srcs_extra;
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
bool props_match = true;
|
||||
|
||||
seen_node.insert(cgraph->nodes[i]);
|
||||
|
||||
if (!res) {
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
|
||||
}
|
||||
@@ -3001,17 +3021,31 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
||||
res = true;
|
||||
}
|
||||
ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
|
||||
|
||||
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
|
||||
ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
|
||||
if (src && seen_node.find(src) == seen_node.end()) {
|
||||
srcs_extra.push_back(src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_leafs; i++) {
|
||||
if (graph->extra.size() != (size_t) srcs_extra.size()) {
|
||||
res = true;
|
||||
graph->extra.resize(srcs_extra.size());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < srcs_extra.size(); ++i) {
|
||||
bool props_match = true;
|
||||
|
||||
if (!res) {
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]);
|
||||
props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
|
||||
}
|
||||
|
||||
if (!props_match) {
|
||||
res = true;
|
||||
}
|
||||
ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
|
||||
ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
|
||||
}
|
||||
|
||||
return res;
|
||||
@@ -3080,63 +3114,166 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
|
||||
static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
|
||||
args.sigmoid = false;
|
||||
args.softmax = false;
|
||||
args.delayed_softmax = false;
|
||||
args.prob_bias = false;
|
||||
args.norm = false;
|
||||
|
||||
const int n_nodes = cgraph->n_nodes;
|
||||
ggml_tensor ** nodes = cgraph->nodes;
|
||||
|
||||
if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
|
||||
args.softmax = true;
|
||||
}
|
||||
|
||||
if (nodes[node_idx]->op == GGML_OP_UNARY) {
|
||||
if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
|
||||
return false;
|
||||
}
|
||||
args.sigmoid = true;
|
||||
}
|
||||
|
||||
if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
|
||||
args.delayed_softmax = true;
|
||||
}
|
||||
|
||||
node_idx++;
|
||||
|
||||
if (args.sigmoid || args.softmax) {
|
||||
// SOFTMAX -> RESHAPE
|
||||
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
|
||||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
|
||||
return false;
|
||||
}
|
||||
ggml_tensor * probs_reshaped = nodes[node_idx];
|
||||
node_idx++;
|
||||
|
||||
if (node_idx >= n_nodes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// src of bias add is the unreshaped probs (-2 instead of -1)
|
||||
if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
|
||||
args.prob_bias = true;
|
||||
node_idx++;
|
||||
}
|
||||
// RESHAPE/ADD -> ARGSORT
|
||||
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
|
||||
return false;
|
||||
} else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
node_idx++;
|
||||
|
||||
// ARGSORT-> VIEW
|
||||
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
|
||||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
|
||||
return false;
|
||||
}
|
||||
node_idx++;
|
||||
|
||||
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// GET_ROWS
|
||||
if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
|
||||
return false;
|
||||
}
|
||||
node_idx++;
|
||||
} else if (args.delayed_softmax) {
|
||||
if (node_idx - 2 < 0) {
|
||||
return false;
|
||||
}
|
||||
ggml_tensor * probs_reshaped = nodes[node_idx - 2];
|
||||
|
||||
// VIEW->ARGSORT
|
||||
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
|
||||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
|
||||
return false;
|
||||
}
|
||||
node_idx++;
|
||||
|
||||
// GET_ROWS
|
||||
if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
|
||||
nodes[node_idx]->src[0] != probs_reshaped) {
|
||||
return false;
|
||||
}
|
||||
node_idx++;
|
||||
|
||||
static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
|
||||
|
||||
for (const ggml_op op : remaining_ops) {
|
||||
if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
|
||||
return false;
|
||||
}
|
||||
node_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
// At this point we can check for norm + scale. Everything is now at least valid till the norm
|
||||
if (node_idx >= n_nodes) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
|
||||
//check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE
|
||||
static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
|
||||
|
||||
args.norm = true;
|
||||
for (const ggml_op op : norm_ops) {
|
||||
if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
|
||||
node_idx++;
|
||||
} else {
|
||||
args.norm = false;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// DIV <- CLAMP, RESHAPE
|
||||
if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
|
||||
nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
|
||||
args.norm = false;
|
||||
return true;
|
||||
}
|
||||
node_idx++;
|
||||
|
||||
if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
|
||||
args.norm = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
node_idx++;
|
||||
}
|
||||
|
||||
if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
|
||||
args.scale = true;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
|
||||
int node_idx,
|
||||
std::initializer_list<enum ggml_op> ops,
|
||||
std::initializer_list<enum ggml_unary_op> unary_ops) {
|
||||
#ifndef NDEBUG
|
||||
const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
|
||||
GGML_ASSERT(unary_ops.size() == num_unary);
|
||||
#endif
|
||||
|
||||
//TODO: remove special case once ggml_can_fuse can handle empty nodes
|
||||
std::initializer_list<enum ggml_op> topk_moe_ops =
|
||||
ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
|
||||
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
|
||||
ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
|
||||
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
|
||||
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
|
||||
|
||||
const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
|
||||
const std::initializer_list<enum ggml_op> & list2) {
|
||||
return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
|
||||
};
|
||||
|
||||
if (is_equal(topk_moe_ops_with_norm, ops) &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
|
||||
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
|
||||
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
|
||||
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
|
||||
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
|
||||
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
|
||||
ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
|
||||
ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
|
||||
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
|
||||
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
|
||||
std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
|
||||
|
||||
@@ -3398,35 +3535,75 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
// start of fusion operations
|
||||
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
|
||||
if (!disable_fusion) {
|
||||
ggml_cuda_topk_moe_args args;
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
|
||||
ggml_tensor * weights = cgraph->nodes[i + 9];
|
||||
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
|
||||
ggml_tensor * clamp = cgraph->nodes[i + 7];
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
|
||||
/*delayed softmax*/ false, clamp);
|
||||
i += 9;
|
||||
continue;
|
||||
}
|
||||
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
|
||||
cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
|
||||
const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
|
||||
ggml_tensor * weights = cgraph->nodes[i + 4];
|
||||
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
|
||||
/*delayed softmax*/ false);
|
||||
i += 4;
|
||||
continue;
|
||||
}
|
||||
std::vector<ggml_op> ops;
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i,
|
||||
ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
|
||||
ggml_tensor * weights = cgraph->nodes[i + 5];
|
||||
ggml_tensor * ids = cgraph->nodes[i + 1];
|
||||
if (can_fuse) {
|
||||
const ggml_tensor * logits = node->src[0];
|
||||
ggml_tensor * weights = nullptr;
|
||||
ggml_tensor * ids = nullptr;
|
||||
const ggml_tensor * bias = nullptr;
|
||||
const ggml_tensor * clamp = nullptr;
|
||||
const ggml_tensor * scale = nullptr;
|
||||
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
|
||||
/*delayed_softmax*/ true);
|
||||
i += 5;
|
||||
continue;
|
||||
if (!args.delayed_softmax) {
|
||||
ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
|
||||
int out_nodes[2]; // nodes which can't be elided
|
||||
|
||||
if (args.prob_bias) {
|
||||
bias = cgraph->nodes[i + 2]->src[1];
|
||||
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS });
|
||||
out_nodes[0] = i + 4;
|
||||
ids = cgraph->nodes[i + 4];
|
||||
} else {
|
||||
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
|
||||
GGML_OP_GET_ROWS });
|
||||
out_nodes[0] = i + 3;
|
||||
ids = cgraph->nodes[i + 3];
|
||||
}
|
||||
|
||||
if (args.norm) {
|
||||
ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
|
||||
GGML_OP_DIV, GGML_OP_RESHAPE });
|
||||
clamp = cgraph->nodes[i + ops.size() - 3];
|
||||
}
|
||||
if (args.scale) {
|
||||
ops.insert(ops.end(), { GGML_OP_SCALE });
|
||||
scale = cgraph->nodes[i + ops.size() - 1];
|
||||
}
|
||||
|
||||
weights = cgraph->nodes[i + ops.size() - 1];
|
||||
out_nodes[1] = i + ops.size() - 1;
|
||||
|
||||
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
|
||||
ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) {
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
|
||||
i += ops.size() - 1;
|
||||
continue;
|
||||
}
|
||||
} else if (!args.norm && !args.prob_bias) {
|
||||
//special case gpt-oss, no norm, no bias.
|
||||
ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
|
||||
GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
|
||||
weights = cgraph->nodes[i + 5];
|
||||
ids = cgraph->nodes[i + 1];
|
||||
const ggml_tensor * softmax = cgraph->nodes[i + 4];
|
||||
|
||||
int out_nodes[2] = { i + 1, i + 5 };
|
||||
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
|
||||
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) {
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
|
||||
i += ops.size() - 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
|
||||
@@ -3733,14 +3910,14 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
// Launch graph
|
||||
CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
|
||||
#else
|
||||
GGML_UNUSED(graph_key);
|
||||
graph_evaluated_or_captured = true;
|
||||
#endif // USE_CUDA_GRAPH
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
if (graph->graph == nullptr) {
|
||||
@@ -3753,12 +3930,8 @@ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, co
|
||||
}
|
||||
|
||||
return graph->is_enabled();
|
||||
#else
|
||||
GGML_UNUSED(cuda_ctx);
|
||||
GGML_UNUSED(graph_key);
|
||||
return false;
|
||||
#endif // USE_CUDA_GRAPH
|
||||
}
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
||||
|
||||
@@ -333,7 +333,33 @@ namespace ggml_cuda_mma {
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return 4 * (threadIdx.x / 16) + l;
|
||||
return ne * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 64;
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return threadIdx.x % 16;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return ne * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
@@ -391,7 +417,22 @@ namespace ggml_cuda_mma {
|
||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
||||
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 32;
|
||||
static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
|
||||
}
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
@@ -945,6 +986,32 @@ namespace ggml_cuda_mma {
|
||||
#endif // AMPERE_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template <data_layout dl_ab, data_layout dl_d>
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
|
||||
#ifdef AMD_MFMA_AVAILABLE
|
||||
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
||||
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
|
||||
#if defined(CDNA3)
|
||||
using floatx2_t = __attribute__((ext_vector_type(2))) float;
|
||||
const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
|
||||
const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
|
||||
#elif defined(CDNA2) || defined(CDNA1)
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(CDNA3)
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // AMD_MFMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
|
||||
const tile<16, 8, int> & A,
|
||||
const tile<8, 8, int> & B,
|
||||
@@ -1054,6 +1121,13 @@ namespace ggml_cuda_mma {
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // RDNA4
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
|
||||
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
||||
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
|
||||
const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
|
||||
const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
@@ -1081,11 +1155,31 @@ namespace ggml_cuda_mma {
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // RDNA4
|
||||
#endif // defined(RDNA4)
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
||||
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
|
||||
#if defined(CDNA3) || defined(CDNA2)
|
||||
using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
|
||||
const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
|
||||
const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
|
||||
#elif defined(CDNA1)
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
|
||||
const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
|
||||
const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // AMPERE_MMA_AVAILABLE
|
||||
#endif // defined(CDNA3) || defined(CDNA2)
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template <data_layout dl_d, data_layout dl_ab>
|
||||
|
||||
@@ -2,6 +2,13 @@
|
||||
#include "mmf.cuh"
|
||||
#include "mmid.cuh"
|
||||
|
||||
static __forceinline__ int mmf_get_rows_per_block(const int cc) {
|
||||
if (GGML_CUDA_CC_IS_CDNA(cc)) {
|
||||
return MMF_ROWS_PER_BLOCK_CDNA;
|
||||
} else {
|
||||
return MMF_ROWS_PER_BLOCK;
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||
@@ -89,28 +96,32 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
||||
ids_info_ptr = &ids_info;
|
||||
}
|
||||
|
||||
const int device = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
const int rows_per_block = mmf_get_rows_per_block(cc);
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
constexpr int vals_per_T = 1;
|
||||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
mul_mat_f_switch_rows_per_block<float>(
|
||||
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
const half2 * src0_d = (const half2 *) src0->data;
|
||||
constexpr int vals_per_T = 2;
|
||||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
mul_mat_f_switch_rows_per_block<half2>(
|
||||
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
|
||||
constexpr int vals_per_T = 2;
|
||||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
mul_mat_f_switch_rows_per_block<nv_bfloat162>(
|
||||
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||
} break;
|
||||
@@ -140,7 +151,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
|
||||
if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -153,6 +168,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||
} else {
|
||||
if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {
|
||||
return false;
|
||||
} else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
|
||||
//TODO: truse CDNA2 as CDNA1, tune the perf when CDNA2 is available.
|
||||
return false;
|
||||
} else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
|
||||
return false;
|
||||
} else if (src1_ncols > 16) {
|
||||
return false;
|
||||
}
|
||||
@@ -160,11 +180,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||
|
||||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
return ampere_mma_available(cc);
|
||||
return ampere_mma_available(cc) || amd_mfma_available(cc);
|
||||
case GGML_TYPE_F16:
|
||||
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
|
||||
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
|
||||
case GGML_TYPE_BF16:
|
||||
return ampere_mma_available(cc) || amd_wmma_available(cc);
|
||||
return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -7,6 +7,31 @@
|
||||
using namespace ggml_cuda_mma;
|
||||
|
||||
#define MMF_ROWS_PER_BLOCK 32
|
||||
#define MMF_ROWS_PER_BLOCK_CDNA 64
|
||||
|
||||
static __forceinline__ int64_t mmf_get_max_block_size(int cc) {
|
||||
if (GGML_CUDA_CC_IS_CDNA(cc)) {
|
||||
return 512;
|
||||
} else {
|
||||
return 256;
|
||||
}
|
||||
}
|
||||
|
||||
static __forceinline__ int mmf_get_padding(int cc) {
|
||||
if (GGML_CUDA_CC_IS_CDNA(cc)) {
|
||||
return 2;
|
||||
} else {
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr __device__ int mmf_get_padding() {
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
return 2;
|
||||
#else
|
||||
return 4;
|
||||
#endif // defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
struct mmf_ids_data {
|
||||
const int32_t * ids_src_compact = nullptr;
|
||||
@@ -29,23 +54,25 @@ static __global__ void mul_mat_f(
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
|
||||
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
|
||||
constexpr bool is_tf32 = std::is_same_v<T, float>;
|
||||
constexpr int tile_B_I = is_tf32 ? 8 : 16;
|
||||
constexpr int tile_C_J = is_tf32 ? 8 : 16;
|
||||
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
|
||||
typedef tile<16, 8, T, ab_layout> tile_A;
|
||||
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
|
||||
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<16, 8, T, get_input_data_layout()> tile_A;
|
||||
typedef tile<16, 8, T, get_input_data_layout()> tile_B;
|
||||
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
|
||||
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
#else
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
|
||||
if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
|
||||
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
|
||||
#else
|
||||
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
@@ -57,7 +84,7 @@ static __global__ void mul_mat_f(
|
||||
}
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
constexpr int tile_k_padded = warp_size + mmf_get_padding();
|
||||
constexpr int ntA = rows_per_block / tile_A::I;
|
||||
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
|
||||
|
||||
@@ -198,7 +225,7 @@ static __global__ void mul_mat_f(
|
||||
}
|
||||
|
||||
float * buf_iw = (float *) compute_base;
|
||||
constexpr int kiw = nwarps*rows_per_block + 4;
|
||||
constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
|
||||
|
||||
if (nwarps > 1) {
|
||||
__syncthreads();
|
||||
@@ -228,27 +255,34 @@ static __global__ void mul_mat_f(
|
||||
return;
|
||||
}
|
||||
|
||||
float sum = 0.0f;
|
||||
static_assert(rows_per_block == warp_size, "need loop/check");
|
||||
float sum[rows_per_block/warp_size] = {0.0f};
|
||||
static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
|
||||
const int i = i0 + i1*warp_size + threadIdx.x;
|
||||
|
||||
sum += buf_iw[j*kiw + i];
|
||||
sum[i1] += buf_iw[j*kiw + i];
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (!has_ids) {
|
||||
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
|
||||
dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
|
||||
}
|
||||
} else {
|
||||
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
|
||||
if (slot >= 0 && (col_base + j) < ncols_dst_total) {
|
||||
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
|
||||
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
}
|
||||
#endif //VOLTA_MMA_AVAILABLE
|
||||
#else
|
||||
GGML_UNUSED_VARS(x, y, ids, dst,
|
||||
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
@@ -256,7 +290,7 @@ static __global__ void mul_mat_f(
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
//This kernel is for larger batch sizes of mul_mat_id
|
||||
@@ -271,23 +305,25 @@ static __global__ void mul_mat_f_ids(
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const uint3 sis1_fd, const uint3 nch_fd) {
|
||||
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
|
||||
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
|
||||
constexpr bool is_tf32 = std::is_same_v<T, float>;
|
||||
constexpr int tile_B_I = is_tf32 ? 8 : 16;
|
||||
constexpr int tile_C_J = is_tf32 ? 8 : 16;
|
||||
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
|
||||
typedef tile<16, 8, T, ab_layout> tile_A;
|
||||
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
|
||||
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<16, 8, T, get_input_data_layout()> tile_A;
|
||||
typedef tile<16, 8, T, get_input_data_layout()> tile_B;
|
||||
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
|
||||
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
#else
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
|
||||
if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
|
||||
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
|
||||
#else
|
||||
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
@@ -300,7 +336,7 @@ static __global__ void mul_mat_f_ids(
|
||||
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
constexpr int tile_k_padded = warp_size + mmf_get_padding();
|
||||
constexpr int ntA = rows_per_block / tile_A::I;
|
||||
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
|
||||
|
||||
@@ -467,7 +503,7 @@ static __global__ void mul_mat_f_ids(
|
||||
}
|
||||
|
||||
float * buf_iw = (float *) compute_base;
|
||||
constexpr int kiw = nwarps*rows_per_block + 4;
|
||||
constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
|
||||
|
||||
if (nwarps > 1) {
|
||||
__syncthreads();
|
||||
@@ -497,13 +533,16 @@ static __global__ void mul_mat_f_ids(
|
||||
return;
|
||||
}
|
||||
|
||||
float sum = 0.0f;
|
||||
static_assert(rows_per_block == warp_size, "need loop/check");
|
||||
float sum[rows_per_block/warp_size] = {0.0f};
|
||||
static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
|
||||
const int i = i0 + i1*warp_size + threadIdx.x;
|
||||
|
||||
sum += buf_iw[j*kiw + i];
|
||||
sum[i1] += buf_iw[j * kiw + i];
|
||||
}
|
||||
}
|
||||
|
||||
const int global_j = col_base + j;
|
||||
@@ -513,23 +552,24 @@ static __global__ void mul_mat_f_ids(
|
||||
const int token = (int) qrm.x;
|
||||
if (token < ncols_dst_total) {
|
||||
const int slot = (int) qrm.y;
|
||||
dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
|
||||
dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
}
|
||||
#endif // VOLTA_MMA_AVAILABLE
|
||||
#else
|
||||
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
|
||||
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template<typename T, int cols_per_block, int nwarps>
|
||||
template<typename T, int rows_per_block, int cols_per_block, int nwarps>
|
||||
static inline void mul_mat_f_switch_ids(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
|
||||
@@ -553,7 +593,7 @@ static inline void mul_mat_f_switch_ids(
|
||||
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
|
||||
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
|
||||
|
||||
mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
|
||||
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
@@ -564,19 +604,19 @@ static inline void mul_mat_f_switch_ids(
|
||||
dim3 block_nums_ids = block_nums;
|
||||
block_nums_ids.y *= col_tiles;
|
||||
|
||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} else {
|
||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
||||
mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int cols_per_block>
|
||||
template <typename T, int rows_per_block, int cols_per_block>
|
||||
void mul_mat_f_cuda(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
||||
@@ -605,7 +645,7 @@ void mul_mat_f_cuda(
|
||||
|
||||
int64_t nwarps_best = 1;
|
||||
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
|
||||
int64_t max_block_size = 256;
|
||||
int64_t max_block_size = mmf_get_max_block_size(cc);
|
||||
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
|
||||
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
|
||||
if (niter < niter_best) {
|
||||
@@ -614,10 +654,9 @@ void mul_mat_f_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
|
||||
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
|
||||
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
|
||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
|
||||
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4;
|
||||
const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I;
|
||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4;
|
||||
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
||||
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
|
||||
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
|
||||
@@ -628,56 +667,56 @@ void mul_mat_f_cuda(
|
||||
|
||||
switch (nwarps_best) {
|
||||
case 1: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 1>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 2>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 3>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 4>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 5>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 6>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 7>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 8>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
@@ -691,7 +730,7 @@ void mul_mat_f_cuda(
|
||||
GGML_UNUSED_VARS(nchannels_y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, int rows_per_block>
|
||||
static void mul_mat_f_switch_cols_per_block(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
||||
@@ -708,82 +747,82 @@ static void mul_mat_f_switch_cols_per_block(
|
||||
|
||||
switch (ncols_case) {
|
||||
case 1: {
|
||||
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 9: {
|
||||
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 10: {
|
||||
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 11: {
|
||||
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 12: {
|
||||
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 13: {
|
||||
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 14: {
|
||||
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 15: {
|
||||
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 16: {
|
||||
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
@@ -793,8 +832,36 @@ static void mul_mat_f_switch_cols_per_block(
|
||||
}
|
||||
}
|
||||
|
||||
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
|
||||
template void mul_mat_f_cuda<T, ncols_dst>( \
|
||||
template <typename T>
|
||||
static void mul_mat_f_switch_rows_per_block(
|
||||
const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t stride_col_id, const int stride_row_id,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
||||
switch (rows_per_block) {
|
||||
case MMF_ROWS_PER_BLOCK: {
|
||||
mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>(
|
||||
x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case MMF_ROWS_PER_BLOCK_CDNA: {
|
||||
mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>(
|
||||
x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
|
||||
}
|
||||
}
|
||||
|
||||
#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
|
||||
template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
|
||||
const T * x, const float * y, const int32_t * ids, float * dst, \
|
||||
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
|
||||
const int64_t stride_col_id, const int64_t stride_row_id, \
|
||||
@@ -803,16 +870,22 @@ static void mul_mat_f_switch_cols_per_block(
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data);
|
||||
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
#if !defined(GGML_USE_MUSA)
|
||||
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
|
||||
extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
|
||||
|
||||
#define DECL_MMF_CASE(ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(float, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(half2, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
|
||||
DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
|
||||
|
||||
DECL_MMF_CASE_EXTERN(1);
|
||||
DECL_MMF_CASE_EXTERN(2);
|
||||
|
||||
@@ -5,6 +5,13 @@
|
||||
#include <cmath>
|
||||
#include <initializer_list>
|
||||
|
||||
// Kernel config struct - passed by value to CUDA kernel
|
||||
struct topk_moe_config {
|
||||
bool use_sigmoid;
|
||||
bool with_norm;
|
||||
bool delayed_softmax;
|
||||
};
|
||||
|
||||
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
|
||||
template <int experts_per_thread, bool use_limit>
|
||||
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
|
||||
@@ -50,6 +57,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
|
||||
}
|
||||
}
|
||||
|
||||
template <int experts_per_thread, bool use_limit>
|
||||
__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const int idx = lane + i * WARP_SIZE;
|
||||
const bool active = !use_limit || (idx < limit);
|
||||
vals[i] = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
This kernel does the following:
|
||||
1. optionally softmax over the logits per token [n_experts, n_tokens]
|
||||
@@ -59,13 +76,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
|
||||
|
||||
It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
|
||||
*/
|
||||
template <int n_experts, bool with_norm, bool delayed_softmax = false>
|
||||
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
|
||||
float * weights,
|
||||
int32_t * ids,
|
||||
const int n_rows,
|
||||
const int n_expert_used,
|
||||
const float clamp_val) {
|
||||
template <int n_experts, bool has_bias>
|
||||
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
|
||||
float * weights,
|
||||
int32_t * ids,
|
||||
float * bias,
|
||||
const int n_rows,
|
||||
const int n_expert_used,
|
||||
const float clamp_val,
|
||||
const float scale_val,
|
||||
const topk_moe_config config) {
|
||||
const int row = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
if (row >= n_rows) {
|
||||
return;
|
||||
@@ -79,14 +99,41 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
|
||||
float wt[experts_per_thread];
|
||||
|
||||
// Initialize all slots to -INFINITY
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
wt[i] = -INFINITY;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < n_experts; i += WARP_SIZE) {
|
||||
const int expert = i + threadIdx.x;
|
||||
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
|
||||
}
|
||||
|
||||
if constexpr (!delayed_softmax) {
|
||||
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
|
||||
if (!config.delayed_softmax) {
|
||||
if (config.use_sigmoid) {
|
||||
sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
|
||||
} else {
|
||||
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
|
||||
}
|
||||
}
|
||||
|
||||
// selection_wt is only needed when bias is present (selection uses wt + bias)
|
||||
// when no bias, we use wt directly for both selection and weight values
|
||||
float selection_wt[has_bias ? experts_per_thread : 1];
|
||||
|
||||
if constexpr (has_bias) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
selection_wt[i] = -INFINITY;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < n_experts; i += WARP_SIZE) {
|
||||
const int expert = i + threadIdx.x;
|
||||
selection_wt[i / WARP_SIZE] =
|
||||
(n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
//at this point, each thread holds either a portion of the softmax distribution
|
||||
@@ -106,22 +153,56 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
float max_val = wt[0];
|
||||
int max_expert = threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 1; i < experts_per_thread; i++) {
|
||||
const int expert = threadIdx.x + i * WARP_SIZE;
|
||||
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
|
||||
max_val = wt[i];
|
||||
max_expert = expert;
|
||||
}
|
||||
}
|
||||
if constexpr (has_bias) {
|
||||
float max_val_s = selection_wt[0];
|
||||
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
|
||||
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
|
||||
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
|
||||
if (val > max_val || (val == max_val && expert < max_expert)) {
|
||||
max_val = val;
|
||||
max_expert = expert;
|
||||
for (int i = 1; i < experts_per_thread; i++) {
|
||||
const int expert = threadIdx.x + i * WARP_SIZE;
|
||||
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
|
||||
max_val = wt[i];
|
||||
max_val_s = selection_wt[i];
|
||||
max_expert = expert;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
|
||||
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
|
||||
const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
|
||||
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
|
||||
if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
|
||||
max_val = val;
|
||||
max_val_s = val_s;
|
||||
max_expert = expert;
|
||||
}
|
||||
}
|
||||
|
||||
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
|
||||
selection_wt[max_expert / WARP_SIZE] = -INFINITY;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 1; i < experts_per_thread; i++) {
|
||||
const int expert = threadIdx.x + i * WARP_SIZE;
|
||||
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
|
||||
max_val = wt[i];
|
||||
max_expert = expert;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
|
||||
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
|
||||
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
|
||||
if (val > max_val || (val == max_val && expert < max_expert)) {
|
||||
max_val = val;
|
||||
max_expert = expert;
|
||||
}
|
||||
}
|
||||
|
||||
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
|
||||
wt[max_expert / WARP_SIZE] = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,16 +211,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
}
|
||||
|
||||
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
|
||||
wt[max_expert / WARP_SIZE] = -INFINITY;
|
||||
|
||||
ids[k] = max_expert;
|
||||
if constexpr (with_norm) {
|
||||
if (config.with_norm) {
|
||||
wt_sum += max_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (with_norm) {
|
||||
if (config.with_norm) {
|
||||
wt_sum = warp_reduce_sum(wt_sum);
|
||||
wt_sum = max(wt_sum, clamp_val);
|
||||
const float inv_sum = 1.0f / wt_sum;
|
||||
@@ -149,7 +228,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (delayed_softmax) {
|
||||
if (config.delayed_softmax) {
|
||||
softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
|
||||
}
|
||||
|
||||
@@ -157,25 +236,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const int idx = i * WARP_SIZE + threadIdx.x;
|
||||
if (idx < n_expert_used) {
|
||||
weights[idx] = output_weights[i];
|
||||
weights[idx] = output_weights[i] * scale_val;
|
||||
}
|
||||
}
|
||||
|
||||
if (!with_norm) {
|
||||
GGML_UNUSED(clamp_val);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool with_norm, bool delayed_softmax = false>
|
||||
template<bool has_bias>
|
||||
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
const float * logits,
|
||||
float * weights,
|
||||
int32_t * ids,
|
||||
float * bias,
|
||||
const int n_rows,
|
||||
const int n_expert,
|
||||
const int n_expert_used,
|
||||
const float clamp_val) {
|
||||
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
|
||||
const float clamp_val,
|
||||
const float scale_val,
|
||||
const topk_moe_config config) {
|
||||
GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
|
||||
"delayed softmax is not supported with weight normalization");
|
||||
const int rows_per_block = 4;
|
||||
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
|
||||
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
|
||||
@@ -183,44 +262,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
|
||||
switch (n_expert) {
|
||||
case 1:
|
||||
topk_moe_cuda<1, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 2:
|
||||
topk_moe_cuda<2, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 4:
|
||||
topk_moe_cuda<4, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 8:
|
||||
topk_moe_cuda<8, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 16:
|
||||
topk_moe_cuda<16, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 32:
|
||||
topk_moe_cuda<32, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 64:
|
||||
topk_moe_cuda<64, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 128:
|
||||
topk_moe_cuda<128, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 256:
|
||||
topk_moe_cuda<256, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 512:
|
||||
topk_moe_cuda<512, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 576:
|
||||
topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "fatal error");
|
||||
@@ -228,13 +311,14 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * logits,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * ids,
|
||||
const bool with_norm,
|
||||
const bool delayed_softmax,
|
||||
ggml_tensor * clamp) {
|
||||
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * logits,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * ids,
|
||||
const ggml_tensor * clamp,
|
||||
const ggml_tensor * scale,
|
||||
const ggml_tensor * bias,
|
||||
const ggml_cuda_topk_moe_args & args) {
|
||||
GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
@@ -245,107 +329,75 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
const float * logits_d = (const float *) logits->data;
|
||||
float * weights_d = (float *) weights->data;
|
||||
int32_t * ids_d = (int32_t *) ids->data;
|
||||
float * bias_d = bias ? (float *) bias->data : nullptr;
|
||||
|
||||
float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;
|
||||
|
||||
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
|
||||
|
||||
const int n_expert_used = weights->ne[1];
|
||||
|
||||
const bool with_norm = clamp != nullptr;
|
||||
|
||||
float clamp_val = -INFINITY;
|
||||
if (with_norm) {
|
||||
if (clamp) {
|
||||
clamp_val = ggml_get_op_params_f32(clamp, 0);
|
||||
}
|
||||
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
|
||||
if (clamp) {
|
||||
clamp_val = ggml_get_op_params_f32(clamp, 0);
|
||||
}
|
||||
|
||||
topk_moe_config config;
|
||||
config.use_sigmoid = args.sigmoid;
|
||||
config.with_norm = with_norm;
|
||||
config.delayed_softmax = args.delayed_softmax;
|
||||
|
||||
if (bias) {
|
||||
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
|
||||
scale_val, config);
|
||||
} else {
|
||||
GGML_ASSERT(clamp == nullptr);
|
||||
if (delayed_softmax) {
|
||||
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
|
||||
clamp_val);
|
||||
} else {
|
||||
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
|
||||
clamp_val);
|
||||
}
|
||||
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
|
||||
scale_val, config);
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
|
||||
const ggml_tensor * weights,
|
||||
const ggml_tensor * get_rows,
|
||||
const ggml_tensor * argsort,
|
||||
const ggml_tensor * clamp,
|
||||
int n_expert) {
|
||||
ggml_tensor * probs = get_rows->src[0];
|
||||
if (probs->op != GGML_OP_RESHAPE) {
|
||||
return false;
|
||||
}
|
||||
probs = probs->src[0];
|
||||
ggml_tensor * selection_probs = argsort->src[0];
|
||||
|
||||
if (probs != selection_probs) {
|
||||
const ggml_tensor * logits,
|
||||
const ggml_tensor * ids) {
|
||||
const int n_expert = ids->nb[1] / ids->nb[0];
|
||||
if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
|
||||
return false;
|
||||
}
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
|
||||
|
||||
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
|
||||
if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (scale != 1.0f || max_bias != 0.0f) {
|
||||
return false;
|
||||
}
|
||||
if (gating_op->op == GGML_OP_SOFT_MAX) {
|
||||
const ggml_tensor * softmax = gating_op;
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
// don't fuse when masks or sinks are present
|
||||
if (softmax->src[1] || softmax->src[2]) {
|
||||
return false;
|
||||
}
|
||||
memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
|
||||
|
||||
// n_expert must be a power of 2
|
||||
if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (clamp) {
|
||||
if (clamp->op != GGML_OP_CLAMP) {
|
||||
if (!ggml_is_contiguous(softmax->src[0])) {
|
||||
return false;
|
||||
}
|
||||
float max_val = ggml_get_op_params_f32(clamp, 1);
|
||||
|
||||
if (max_val != INFINITY) {
|
||||
if (scale != 1.0f || max_bias != 0.0f) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// don't fuse when masks or sinks are present
|
||||
if (softmax->src[1] || softmax->src[2]) {
|
||||
return false;
|
||||
}
|
||||
} else if (gating_op->op == GGML_OP_UNARY) {
|
||||
ggml_unary_op op = ggml_get_unary_op(gating_op);
|
||||
|
||||
if (op != GGML_UNARY_OP_SIGMOID) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
|
||||
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
|
||||
GGML_OP_RESHAPE };
|
||||
|
||||
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
||||
|
||||
static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
|
||||
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
|
||||
|
||||
GGML_ASSERT(!norm || !delayed_softmax);
|
||||
|
||||
if (delayed_softmax) {
|
||||
return delayed_softmax_ops;
|
||||
}
|
||||
|
||||
if (norm) {
|
||||
return norm_ops;
|
||||
}
|
||||
|
||||
return no_norm_ops;
|
||||
}
|
||||
|
||||
@@ -3,19 +3,25 @@
|
||||
|
||||
#include <initializer_list>
|
||||
|
||||
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * logits,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * ids,
|
||||
const bool with_norm,
|
||||
const bool delayed_softmax = false,
|
||||
ggml_tensor * weight_clamp = nullptr);
|
||||
struct ggml_cuda_topk_moe_args {
|
||||
bool sigmoid{};
|
||||
bool softmax{};
|
||||
bool delayed_softmax{};
|
||||
bool prob_bias{};
|
||||
bool norm{};
|
||||
bool scale{};
|
||||
};
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
|
||||
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * logits,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * ids,
|
||||
const ggml_tensor * clamp,
|
||||
const ggml_tensor * scale,
|
||||
const ggml_tensor * bias,
|
||||
const ggml_cuda_topk_moe_args & args);
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
|
||||
const ggml_tensor * weights,
|
||||
const ggml_tensor * get_rows,
|
||||
const ggml_tensor * argsort,
|
||||
const ggml_tensor * clamp,
|
||||
int n_expert);
|
||||
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
|
||||
const ggml_tensor * logits,
|
||||
const ggml_tensor * ids);
|
||||
|
||||
@@ -1,7 +1,17 @@
|
||||
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
|
||||
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
|
||||
|
||||
if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}" OR NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
|
||||
message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.")
|
||||
endif()
|
||||
|
||||
message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels")
|
||||
|
||||
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||
include(ExternalProject)
|
||||
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
|
||||
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
|
||||
|
||||
add_library(htp_iface OBJECT
|
||||
@@ -25,56 +35,71 @@ else()
|
||||
target_link_options(htp_iface PUBLIC -ldl)
|
||||
endif()
|
||||
|
||||
link_custom_library(htp_iface cdsprpc)
|
||||
link_custom_library(htp_iface rpcmem)
|
||||
|
||||
set(TARGET_NAME ggml-hexagon)
|
||||
ggml_add_backend_library(${TARGET_NAME}
|
||||
ggml-hexagon.cpp htp-utils.c htp-utils.h ../../include/ggml-hexagon.h)
|
||||
ggml-hexagon.cpp
|
||||
htp-drv.cpp
|
||||
htp-drv.h
|
||||
libdl.h
|
||||
../../include/ggml-hexagon.h)
|
||||
|
||||
target_link_libraries(${TARGET_NAME} PRIVATE htp_iface)
|
||||
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
# Build HTP bits
|
||||
set(HTP_CMAKE_ARGS
|
||||
-DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
|
||||
-DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT}
|
||||
-DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT}
|
||||
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
|
||||
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
|
||||
# Build HTP skels
|
||||
set(HTP_SKELS)
|
||||
function(build_htp_skel V)
|
||||
ExternalProject_Add(htp-${V}
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
BUILD_BYPRODUCTS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so
|
||||
CMAKE_ARGS
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake
|
||||
-DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
|
||||
-DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT}
|
||||
-DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT}
|
||||
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
|
||||
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}
|
||||
-DDSP_VERSION=${V}
|
||||
-DPREBUILT_LIB_DIR="toolv19_${V}")
|
||||
list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so)
|
||||
set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
ExternalProject_Add(htp-v68
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v68 -DPREBUILT_LIB_DIR="toolv19_v68")
|
||||
|
||||
ExternalProject_Add(htp-v69
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v69 -DPREBUILT_LIB_DIR="toolv19_v69")
|
||||
|
||||
ExternalProject_Add(htp-v73
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73")
|
||||
|
||||
ExternalProject_Add(htp-v75
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v75 -DPREBUILT_LIB_DIR="toolv19_v75")
|
||||
|
||||
ExternalProject_Add(htp-v79
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v79 -DPREBUILT_LIB_DIR="toolv19_v79")
|
||||
|
||||
ExternalProject_Add(htp-v81
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v81 -DPREBUILT_LIB_DIR="toolv19_v81")
|
||||
build_htp_skel(v68)
|
||||
build_htp_skel(v69)
|
||||
build_htp_skel(v73)
|
||||
build_htp_skel(v75)
|
||||
build_htp_skel(v79)
|
||||
build_htp_skel(v81)
|
||||
|
||||
# Install Hexagon skels required at runtime
|
||||
install(FILES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v68.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v69.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v81.so
|
||||
TYPE LIB)
|
||||
install(FILES ${HTP_SKELS} TYPE LIB)
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES Windows AND GGML_HEXAGON_HTP_CERT)
|
||||
file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/arm64" WINSDK_BIN0_ARM64)
|
||||
file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/x86" WINSDK_BIN0_X86)
|
||||
file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/arm64" WINSDK_BIN1_ARM64)
|
||||
file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/x86" WINSDK_BIN1_X86)
|
||||
|
||||
set(WINSDK_PATHS ${WINSDK_BIN0_ARM64} ${WINSDK_BIN0_X86} ${WINSDK_BIN1_ARM64} ${WINSDK_BIN1_X86})
|
||||
|
||||
find_program(INF2CAT NAMES inf2cat.exe PATHS ${WINSDK_PATHS} REQUIRED)
|
||||
find_program(SIGNTOOL NAMES signtool.exe PATHS ${WINSDK_PATHS} REQUIRED)
|
||||
|
||||
message(STATUS "hexagon: using ${GGML_HEXAGON_HTP_CERT} to sign libggml-htp skels")
|
||||
|
||||
set(LIBGGML_HTP_CAT ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp.cat)
|
||||
add_custom_target(libggml-htp-cat
|
||||
BYPRODUCTS ${LIBGGML_HTP_CAT}
|
||||
DEPENDS libggml-htp.inf ${HTP_SKELS}
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/libggml-htp.inf ${CMAKE_CURRENT_BINARY_DIR}
|
||||
COMMAND ${INF2CAT} /driver:${CMAKE_CURRENT_BINARY_DIR} /os:10_25H2_ARM64
|
||||
COMMAND ${SIGNTOOL} sign /fd sha256 /f ${GGML_HEXAGON_HTP_CERT} ${LIBGGML_HTP_CAT}
|
||||
COMMENT "generating and signing libggml-htp.cat file"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
add_dependencies(${TARGET_NAME} libggml-htp-cat)
|
||||
install(FILES ${LIBGGML_HTP_CAT} TYPE LIB)
|
||||
endif()
|
||||
|
||||
@@ -14,9 +14,6 @@
|
||||
|
||||
#ifdef _WIN32
|
||||
# include <sal.h>
|
||||
# ifndef _WINDOWS
|
||||
# define _WINDOWS
|
||||
# endif
|
||||
#else
|
||||
# include <semaphore.h>
|
||||
# include <unistd.h>
|
||||
@@ -25,8 +22,6 @@
|
||||
#pragma clang diagnostic ignored "-Wnested-anon-types"
|
||||
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
|
||||
|
||||
#include "htp-utils.h"
|
||||
|
||||
#include <AEEStdErr.h>
|
||||
#include <dspqueue.h>
|
||||
#include <rpcmem.h>
|
||||
@@ -40,6 +35,7 @@
|
||||
#include "op-desc.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp_iface.h"
|
||||
#include "htp-drv.h"
|
||||
|
||||
static size_t opt_ndev = 1;
|
||||
static size_t opt_nhvx = 0; // use all
|
||||
@@ -150,9 +146,9 @@ void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_
|
||||
0, // flags - the framework will autoset this
|
||||
n_bufs, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
sizeof(req), // Message length
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000 // Timeout
|
||||
DSPQUEUE_TIMEOUT // Timeout
|
||||
);
|
||||
|
||||
if (err != 0) {
|
||||
@@ -182,13 +178,13 @@ void ggml_hexagon_session::flush() {
|
||||
|
||||
// Read response packet from queue
|
||||
int err = dspqueue_read(q, &flags,
|
||||
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
|
||||
&n_bufs, // Number of buffer references
|
||||
bufs, // Buffer references
|
||||
sizeof(rsp), // Max message length
|
||||
&rsp_size, // Message length
|
||||
(uint8_t *) &rsp,
|
||||
1000000); // Timeout
|
||||
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
|
||||
&n_bufs, // Number of buffer references
|
||||
bufs, // Buffer references
|
||||
sizeof(rsp), // Max message length
|
||||
&rsp_size, // Message length
|
||||
(uint8_t *) &rsp, // Message
|
||||
DSPQUEUE_TIMEOUT); // Timeout
|
||||
|
||||
if (err == AEE_EEXPIRED) {
|
||||
// TODO: might need to bail out if the HTP is stuck on something
|
||||
@@ -269,13 +265,7 @@ struct ggml_backend_hexagon_buffer_context {
|
||||
ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) {
|
||||
size += 4 * 1024; // extra page for padding
|
||||
|
||||
if (rpcmem_alloc2) {
|
||||
this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
|
||||
} else {
|
||||
GGML_LOG_INFO("ggml-hex: %s rpcmem_alloc2 not found, falling back to rpcmem_alloc\n", sess->name.c_str());
|
||||
this->base = (uint8_t *) rpcmem_alloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
|
||||
}
|
||||
|
||||
this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
|
||||
if (!this->base) {
|
||||
GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size);
|
||||
throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)");
|
||||
@@ -2461,12 +2451,12 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) {
|
||||
}
|
||||
|
||||
static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
|
||||
return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type) && ggml_is_quantized(op1->src[1]->type));
|
||||
return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type));
|
||||
}
|
||||
|
||||
static inline bool is_compute_op(ggml_tensor *node)
|
||||
{
|
||||
return !(ggml_op_is_empty(node->op) || ggml_is_empty(node));
|
||||
return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);
|
||||
}
|
||||
|
||||
// scan the graph and figure out last compute op index
|
||||
@@ -2488,7 +2478,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
|
||||
const int last = last_compute_op(graph);
|
||||
|
||||
const struct ggml_tensor * prev_quant_op = nullptr; // prev executed op with quantizer
|
||||
const struct ggml_tensor * prev_op = nullptr; // prev executed op
|
||||
|
||||
for (int i = 0; i < graph->n_nodes; ++i) {
|
||||
ggml_tensor * node = graph->nodes[i];
|
||||
@@ -2497,17 +2487,15 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t flags = 0;
|
||||
|
||||
// skip quantizer if src1 is reused
|
||||
if (op_reuse_src1(node, prev_quant_op)) {
|
||||
if (op_reuse_src1(node, prev_op)) {
|
||||
flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
|
||||
}
|
||||
|
||||
prev_op = node;
|
||||
|
||||
// ask for early notification for the last Op
|
||||
if (i == last) {
|
||||
flags |= HTP_OPFLAGS_EARLY_WAKEUP;
|
||||
@@ -2520,7 +2508,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
} else {
|
||||
ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
|
||||
}
|
||||
prev_quant_op = node;
|
||||
break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
if (ggml_is_quantized(node->src[0]->type)) {
|
||||
@@ -2528,7 +2515,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
} else {
|
||||
ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
|
||||
}
|
||||
prev_quant_op = node;
|
||||
break;
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_ADD:
|
||||
@@ -2670,7 +2656,7 @@ static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<no
|
||||
}
|
||||
|
||||
// that many nodes forward to search for stackable nodes that can reuse VTCM
|
||||
constexpr int N_FORWARD = 8;
|
||||
constexpr int N_FORWARD = 16;
|
||||
|
||||
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
|
||||
if (used[i1]) {
|
||||
@@ -3056,10 +3042,12 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
if (opt_arch < 75) {
|
||||
opt_ndev = 1;
|
||||
GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch);
|
||||
|
||||
@@ -3156,6 +3144,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
opt_arch = strtoul(str_arch, NULL, 0);
|
||||
}
|
||||
|
||||
opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1;
|
||||
|
||||
reg->context = new ggml_hexagon_registry(reg);
|
||||
|
||||
HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req),
|
||||
@@ -3180,6 +3170,11 @@ ggml_backend_reg_t ggml_backend_hexagon_reg(void) {
|
||||
static std::mutex mutex;
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
if (!initialized) {
|
||||
auto nErr = htpdrv_init();
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
ggml_hexagon_init(®);
|
||||
}
|
||||
|
||||
|
||||
418
ggml/src/ggml-hexagon/htp-drv.cpp
Normal file
418
ggml/src/ggml-hexagon/htp-drv.cpp
Normal file
@@ -0,0 +1,418 @@
|
||||
// sample drv interface
|
||||
|
||||
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
|
||||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
#pragma clang diagnostic ignored "-Wsign-compare"
|
||||
|
||||
#include <filesystem>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#ifdef _WIN32
|
||||
# define WIN32_LEAN_AND_MEAN
|
||||
# ifndef NOMINMAX
|
||||
# define NOMINMAX
|
||||
# endif
|
||||
# include <windows.h>
|
||||
# include <winevt.h>
|
||||
#else
|
||||
# include <dlfcn.h>
|
||||
# include <unistd.h>
|
||||
#endif
|
||||
#include "ggml-impl.h"
|
||||
#include "htp-drv.h"
|
||||
#include "libdl.h"
|
||||
|
||||
#include <domain.h>
|
||||
|
||||
//
|
||||
// Driver API types
|
||||
//
|
||||
|
||||
typedef void * (*rpcmem_alloc_pfn_t)(int heapid, uint32_t flags, int size);
|
||||
typedef void * (*rpcmem_alloc2_pfn_t)(int heapid, uint32_t flags, size_t size);
|
||||
typedef void (*rpcmem_free_pfn_t)(void * po);
|
||||
typedef int (*rpcmem_to_fd_pfn_t)(void * po);
|
||||
|
||||
typedef AEEResult (*dspqueue_create_pfn_t)(int domain,
|
||||
uint32_t flags,
|
||||
uint32_t req_queue_size,
|
||||
uint32_t resp_queue_size,
|
||||
dspqueue_callback_t packet_callback,
|
||||
dspqueue_callback_t error_callback,
|
||||
void * callback_context,
|
||||
dspqueue_t * queue);
|
||||
typedef AEEResult (*dspqueue_close_pfn_t)(dspqueue_t queue);
|
||||
typedef AEEResult (*dspqueue_export_pfn_t)(dspqueue_t queue, uint64_t *queue_id);
|
||||
typedef AEEResult (*dspqueue_write_pfn_t)(dspqueue_t queue, uint32_t flags,
|
||||
uint32_t num_buffers,
|
||||
struct dspqueue_buffer *buffers,
|
||||
uint32_t message_length,
|
||||
const uint8_t *message,
|
||||
uint32_t timeout_us);
|
||||
typedef AEEResult (*dspqueue_read_pfn_t)(dspqueue_t queue, uint32_t *flags,
|
||||
uint32_t max_buffers, uint32_t *num_buffers,
|
||||
struct dspqueue_buffer *buffers,
|
||||
uint32_t max_message_length,
|
||||
uint32_t *message_length, uint8_t *message,
|
||||
uint32_t timeout_us);
|
||||
|
||||
typedef int (*fastrpc_mmap_pfn_t)(int domain, int fd, void *addr, int offset, size_t length, enum fastrpc_map_flags flags);
|
||||
typedef int (*fastrpc_munmap_pfn_t)(int domain, int fd, void *addr, size_t length);
|
||||
|
||||
typedef int (*remote_handle64_open_pfn_t)(const char* name, remote_handle64 *ph);
|
||||
typedef int (*remote_handle64_invoke_pfn_t)(remote_handle64 h, uint32_t dwScalars, remote_arg *pra);
|
||||
typedef int (*remote_handle64_close_pfn_t)(remote_handle h);
|
||||
typedef int (*remote_handle_control_pfn_t)(uint32_t req, void* data, uint32_t datalen);
|
||||
typedef int (*remote_handle64_control_pfn_t)(remote_handle64 h, uint32_t req, void* data, uint32_t datalen);
|
||||
typedef int (*remote_session_control_pfn_t)(uint32_t req, void *data, uint32_t datalen);
|
||||
|
||||
//
|
||||
// Driver API pfns
|
||||
//
|
||||
|
||||
rpcmem_alloc_pfn_t rpcmem_alloc_pfn = nullptr;
|
||||
rpcmem_alloc2_pfn_t rpcmem_alloc2_pfn = nullptr;
|
||||
rpcmem_free_pfn_t rpcmem_free_pfn = nullptr;
|
||||
rpcmem_to_fd_pfn_t rpcmem_to_fd_pfn = nullptr;
|
||||
|
||||
fastrpc_mmap_pfn_t fastrpc_mmap_pfn = nullptr;
|
||||
fastrpc_munmap_pfn_t fastrpc_munmap_pfn = nullptr;
|
||||
|
||||
dspqueue_create_pfn_t dspqueue_create_pfn = nullptr;
|
||||
dspqueue_close_pfn_t dspqueue_close_pfn = nullptr;
|
||||
dspqueue_export_pfn_t dspqueue_export_pfn = nullptr;
|
||||
dspqueue_write_pfn_t dspqueue_write_pfn = nullptr;
|
||||
dspqueue_read_pfn_t dspqueue_read_pfn = nullptr;
|
||||
|
||||
remote_handle64_open_pfn_t remote_handle64_open_pfn = nullptr;
|
||||
remote_handle64_invoke_pfn_t remote_handle64_invoke_pfn = nullptr;
|
||||
remote_handle64_close_pfn_t remote_handle64_close_pfn = nullptr;
|
||||
remote_handle_control_pfn_t remote_handle_control_pfn = nullptr;
|
||||
remote_handle64_control_pfn_t remote_handle64_control_pfn = nullptr;
|
||||
remote_session_control_pfn_t remote_session_control_pfn = nullptr;
|
||||
|
||||
//
|
||||
// Driver API
|
||||
//
|
||||
|
||||
void * rpcmem_alloc(int heapid, uint32_t flags, int size) {
|
||||
return rpcmem_alloc_pfn(heapid, flags, size);
|
||||
}
|
||||
|
||||
void * rpcmem_alloc2(int heapid, uint32_t flags, size_t size) {
|
||||
if (rpcmem_alloc2_pfn) {
|
||||
return rpcmem_alloc2_pfn(heapid, flags, size);
|
||||
} else {
|
||||
GGML_LOG_INFO("ggml-hex: rpcmem_alloc2 not found, falling back to rpcmem_alloc\n");
|
||||
return rpcmem_alloc_pfn(heapid, flags, size);
|
||||
}
|
||||
}
|
||||
|
||||
void rpcmem_free(void * po) {
|
||||
return rpcmem_free_pfn(po);
|
||||
}
|
||||
|
||||
int rpcmem_to_fd(void * po) {
|
||||
return rpcmem_to_fd_pfn(po);
|
||||
}
|
||||
|
||||
HTPDRV_API int fastrpc_mmap(int domain, int fd, void * addr, int offset, size_t length, enum fastrpc_map_flags flags) {
|
||||
return fastrpc_mmap_pfn(domain, fd, addr, offset, length, flags);
|
||||
}
|
||||
|
||||
HTPDRV_API int fastrpc_munmap(int domain, int fd, void * addr, size_t length) {
|
||||
return fastrpc_munmap_pfn(domain, fd, addr, length);
|
||||
}
|
||||
|
||||
AEEResult dspqueue_create(int domain,
|
||||
uint32_t flags,
|
||||
uint32_t req_queue_size,
|
||||
uint32_t resp_queue_size,
|
||||
dspqueue_callback_t packet_callback,
|
||||
dspqueue_callback_t error_callback,
|
||||
void * callback_context,
|
||||
dspqueue_t * queue) {
|
||||
return dspqueue_create_pfn(domain, flags, req_queue_size, resp_queue_size, packet_callback, error_callback,
|
||||
callback_context, queue);
|
||||
}
|
||||
|
||||
AEEResult dspqueue_close(dspqueue_t queue) {
|
||||
return dspqueue_close_pfn(queue);
|
||||
}
|
||||
|
||||
AEEResult dspqueue_export(dspqueue_t queue, uint64_t * queue_id) {
|
||||
return dspqueue_export_pfn(queue, queue_id);
|
||||
}
|
||||
|
||||
AEEResult dspqueue_write(dspqueue_t queue,
|
||||
uint32_t flags,
|
||||
uint32_t num_buffers,
|
||||
struct dspqueue_buffer * buffers,
|
||||
uint32_t message_length,
|
||||
const uint8_t * message,
|
||||
uint32_t timeout_us) {
|
||||
return dspqueue_write_pfn(queue, flags, num_buffers, buffers, message_length, message, timeout_us);
|
||||
}
|
||||
|
||||
AEEResult dspqueue_read(dspqueue_t queue,
|
||||
uint32_t * flags,
|
||||
uint32_t max_buffers,
|
||||
uint32_t * num_buffers,
|
||||
struct dspqueue_buffer * buffers,
|
||||
uint32_t max_message_length,
|
||||
uint32_t * message_length,
|
||||
uint8_t * message,
|
||||
uint32_t timeout_us) {
|
||||
return dspqueue_read_pfn(queue, flags, max_buffers, num_buffers, buffers, max_message_length, message_length,
|
||||
message, timeout_us);
|
||||
}
|
||||
|
||||
HTPDRV_API int remote_handle64_open(const char * name, remote_handle64 * ph) {
|
||||
return remote_handle64_open_pfn(name, ph);
|
||||
}
|
||||
|
||||
HTPDRV_API int remote_handle64_invoke(remote_handle64 h, uint32_t dwScalars, remote_arg * pra) {
|
||||
return remote_handle64_invoke_pfn(h, dwScalars, pra);
|
||||
}
|
||||
|
||||
HTPDRV_API int remote_handle64_close(remote_handle64 h) {
|
||||
return remote_handle64_close_pfn(h);
|
||||
}
|
||||
|
||||
HTPDRV_API int remote_handle_control(uint32_t req, void * data, uint32_t datalen) {
|
||||
return remote_handle_control_pfn(req, data, datalen);
|
||||
}
|
||||
|
||||
HTPDRV_API int remote_handle64_control(remote_handle64 h, uint32_t req, void * data, uint32_t datalen) {
|
||||
return remote_handle64_control_pfn(h, req, data, datalen);
|
||||
}
|
||||
|
||||
HTPDRV_API int remote_session_control(uint32_t req, void * data, uint32_t datalen) {
|
||||
return remote_session_control_pfn(req, data, datalen);
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
static std::string wstr_to_str(std::wstring_view wstr) {
|
||||
std::string result;
|
||||
if (wstr.empty()) {
|
||||
return result;
|
||||
}
|
||||
auto bytes_needed = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,
|
||||
wstr.data(), (int) wstr.size(),
|
||||
nullptr, 0, nullptr, nullptr);
|
||||
if (bytes_needed == 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
|
||||
throw std::runtime_error("Invalid wstring input");
|
||||
}
|
||||
|
||||
result.resize(bytes_needed, '\0');
|
||||
int bytes_written = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,
|
||||
wstr.data(), (int) wstr.size(),
|
||||
result.data(), bytes_needed,
|
||||
nullptr, nullptr);
|
||||
if (bytes_written == 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
|
||||
throw std::runtime_error("Wstring conversion failed");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::string get_driver_path() {
|
||||
std::wstring serviceName = L"qcnspmcdm";
|
||||
std::string result;
|
||||
|
||||
// Get a handle to the SCM database.
|
||||
SC_HANDLE schSCManager = OpenSCManagerW(NULL, NULL, STANDARD_RIGHTS_READ);
|
||||
if (nullptr == schSCManager) {
|
||||
GGML_LOG_ERROR("ggml-hex: Failed to open SCManager. Error: %lu\n", GetLastError());
|
||||
return result;
|
||||
}
|
||||
|
||||
// Get a handle to the service.
|
||||
SC_HANDLE schService = OpenServiceW(schSCManager, // SCM database
|
||||
serviceName.c_str(), // name of service
|
||||
SERVICE_QUERY_CONFIG); // need query config access
|
||||
|
||||
if (nullptr == schService) {
|
||||
GGML_LOG_ERROR("ggml-hex: Failed to open qcnspmcdm service. Error: %lu\n", GetLastError());
|
||||
CloseServiceHandle(schSCManager);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Store the size of buffer used as an output.
|
||||
DWORD bufferSize;
|
||||
if (!QueryServiceConfigW(schService, NULL, 0, &bufferSize) &&
|
||||
(GetLastError() != ERROR_INSUFFICIENT_BUFFER)) {
|
||||
GGML_LOG_ERROR("ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
|
||||
CloseServiceHandle(schService);
|
||||
CloseServiceHandle(schSCManager);
|
||||
return result;
|
||||
}
|
||||
// Get the configuration of the service.
|
||||
LPQUERY_SERVICE_CONFIGW serviceConfig =
|
||||
static_cast<LPQUERY_SERVICE_CONFIGW>(LocalAlloc(LMEM_FIXED, bufferSize));
|
||||
if (!QueryServiceConfigW(schService, serviceConfig, bufferSize, &bufferSize)) {
|
||||
fprintf(stderr, "ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
|
||||
LocalFree(serviceConfig);
|
||||
CloseServiceHandle(schService);
|
||||
CloseServiceHandle(schSCManager);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Read the driver file path get its parent directory
|
||||
std::wstring driverPath = std::wstring(serviceConfig->lpBinaryPathName);
|
||||
driverPath = driverPath.substr(0, driverPath.find_last_of(L"\\"));
|
||||
|
||||
// Clean up resources
|
||||
LocalFree(serviceConfig);
|
||||
CloseServiceHandle(schService);
|
||||
CloseServiceHandle(schSCManager);
|
||||
|
||||
// Driver path would contain invalid path string, like:
|
||||
// \SystemRoot\System32\DriverStore\FileRepository\qcadsprpc8280.inf_arm64_c2b9460c9a072f37
|
||||
// "\SystemRoot" should be replace with a correct one (e.g. C:\Windows)
|
||||
const std::wstring systemRootPlaceholder = L"\\SystemRoot";
|
||||
if (0 != driverPath.compare(0, systemRootPlaceholder.length(), systemRootPlaceholder)) {
|
||||
GGML_LOG_ERROR("ggml-hex: String pattern not found in driver path.\n");
|
||||
return result;
|
||||
}
|
||||
|
||||
// Replace \SystemRoot with an absolute path from system ENV windir
|
||||
const std::wstring systemRootEnv = L"windir";
|
||||
|
||||
// Query the number of wide charactors this variable requires
|
||||
DWORD numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), NULL, 0);
|
||||
if (numWords == 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: Failed get systemRoot environment variable\n");
|
||||
return result;
|
||||
}
|
||||
|
||||
// Query the actual system root name from environment variable
|
||||
std::vector<wchar_t> systemRoot(numWords + 1);
|
||||
numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), systemRoot.data(), numWords + 1);
|
||||
if (numWords == 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: Failed to read windir environment variable\n");
|
||||
return result;
|
||||
}
|
||||
driverPath.replace(0, systemRootPlaceholder.length(), std::wstring(systemRoot.data()));
|
||||
|
||||
return wstr_to_str(driverPath);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
|
||||
|
||||
int htpdrv_init() {
|
||||
static dl_handle_ptr lib_cdsp_rpc_handle = nullptr;
|
||||
static bool initialized = false;
|
||||
#ifdef _WIN32
|
||||
std::string drv_path = get_driver_path() + "\\" + "libcdsprpc.dll";
|
||||
#else
|
||||
std::string drv_path = "libcdsprpc.so";
|
||||
#endif
|
||||
if (initialized) {
|
||||
GGML_LOG_INFO("ggml-hex: Driver already loaded\n");
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
GGML_LOG_INFO("ggml-hex: Loading driver %s\n", drv_path.c_str());
|
||||
|
||||
fs::path path{ drv_path.c_str() };
|
||||
dl_handle_ptr handle { dl_load_library(path) };
|
||||
if (!handle) {
|
||||
GGML_LOG_ERROR("ggml-hex: failed to load %s: %s\n", path.u8string().c_str(), dl_error());
|
||||
return AEE_EUNABLETOLOAD;
|
||||
}
|
||||
|
||||
#define dlsym(drv, type, pfn, symbol, ignore) \
|
||||
do { \
|
||||
pfn = (type) dl_get_sym(drv, #symbol); \
|
||||
if (!ignore && nullptr == pfn) { \
|
||||
GGML_LOG_ERROR("ggml-hex: failed to dlsym %s\n", #symbol); \
|
||||
return AEE_EUNABLETOLOAD; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
dlsym(handle.get(), rpcmem_alloc_pfn_t, rpcmem_alloc_pfn, rpcmem_alloc, false);
|
||||
dlsym(handle.get(), rpcmem_alloc2_pfn_t, rpcmem_alloc2_pfn, rpcmem_alloc2, true);
|
||||
dlsym(handle.get(), rpcmem_free_pfn_t, rpcmem_free_pfn, rpcmem_free, false);
|
||||
dlsym(handle.get(), rpcmem_to_fd_pfn_t, rpcmem_to_fd_pfn, rpcmem_to_fd, false);
|
||||
dlsym(handle.get(), fastrpc_mmap_pfn_t, fastrpc_mmap_pfn, fastrpc_mmap, false);
|
||||
dlsym(handle.get(), fastrpc_munmap_pfn_t, fastrpc_munmap_pfn, fastrpc_munmap, false);
|
||||
dlsym(handle.get(), dspqueue_create_pfn_t, dspqueue_create_pfn, dspqueue_create, false);
|
||||
dlsym(handle.get(), dspqueue_close_pfn_t, dspqueue_close_pfn, dspqueue_close, false);
|
||||
dlsym(handle.get(), dspqueue_export_pfn_t, dspqueue_export_pfn, dspqueue_export, false);
|
||||
dlsym(handle.get(), dspqueue_write_pfn_t, dspqueue_write_pfn, dspqueue_write, false);
|
||||
dlsym(handle.get(), dspqueue_read_pfn_t, dspqueue_read_pfn, dspqueue_read, false);
|
||||
dlsym(handle.get(), remote_handle64_open_pfn_t, remote_handle64_open_pfn, remote_handle64_open, false);
|
||||
dlsym(handle.get(), remote_handle64_invoke_pfn_t, remote_handle64_invoke_pfn, remote_handle64_invoke, false);
|
||||
dlsym(handle.get(), remote_handle_control_pfn_t, remote_handle_control_pfn, remote_handle_control, false);
|
||||
dlsym(handle.get(), remote_handle64_control_pfn_t, remote_handle64_control_pfn, remote_handle64_control, false);
|
||||
dlsym(handle.get(), remote_session_control_pfn_t, remote_session_control_pfn, remote_session_control, false);
|
||||
dlsym(handle.get(), remote_handle64_close_pfn_t, remote_handle64_close_pfn, remote_handle64_close, false);
|
||||
|
||||
lib_cdsp_rpc_handle = std::move(handle);
|
||||
initialized = true;
|
||||
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
domain * get_domain(int domain_id) {
|
||||
int i = 0;
|
||||
int size = sizeof(supported_domains) / sizeof(domain);
|
||||
|
||||
for (i = 0; i < size; i++) {
|
||||
if (supported_domains[i].id == domain_id) {
|
||||
return &supported_domains[i];
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
int get_hex_arch_ver(int domain, int * arch) {
|
||||
if (!remote_handle_control_pfn) {
|
||||
GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
|
||||
return AEE_EUNSUPPORTEDAPI;
|
||||
}
|
||||
|
||||
struct remote_dsp_capability arch_ver;
|
||||
arch_ver.domain = (uint32_t) domain;
|
||||
arch_ver.attribute_ID = ARCH_VER;
|
||||
arch_ver.capability = (uint32_t) 0;
|
||||
|
||||
int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
|
||||
if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
|
||||
GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
|
||||
return AEE_EUNSUPPORTEDAPI;
|
||||
}
|
||||
|
||||
if (err != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
|
||||
switch (arch_ver.capability & 0xff) {
|
||||
case 0x68:
|
||||
*arch = 68;
|
||||
return 0;
|
||||
case 0x69:
|
||||
*arch = 69;
|
||||
return 0;
|
||||
case 0x73:
|
||||
*arch = 73;
|
||||
return 0;
|
||||
case 0x75:
|
||||
*arch = 75;
|
||||
return 0;
|
||||
case 0x79:
|
||||
*arch = 79;
|
||||
return 0;
|
||||
case 0x81:
|
||||
*arch = 81;
|
||||
return 0;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
121
ggml/src/ggml-hexagon/htp-drv.h
Normal file
121
ggml/src/ggml-hexagon/htp-drv.h
Normal file
@@ -0,0 +1,121 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
# pragma clang diagnostic ignored "-Wignored-attributes"
|
||||
#endif
|
||||
|
||||
#include <AEEStdErr.h>
|
||||
#include <rpcmem.h>
|
||||
#include <remote.h>
|
||||
#include <dspqueue.h>
|
||||
|
||||
#if defined(_WIN32) && !defined(__MINGW32__)
|
||||
# ifdef GGML_BACKEND_BUILD
|
||||
# define HTPDRV_API __declspec(dllexport) extern
|
||||
# else
|
||||
# define HTPDRV_API __declspec(dllimport) extern
|
||||
# endif
|
||||
#else
|
||||
# define HTPDRV_API __attribute__ ((visibility ("default"))) extern
|
||||
#endif
|
||||
|
||||
/* Offset to differentiate HLOS and Hexagon error codes.
|
||||
Stores the value of AEE_EOFFSET for Hexagon. */
|
||||
#ifndef DSP_OFFSET
|
||||
# define DSP_OFFSET 0x80000400
|
||||
#endif
|
||||
|
||||
/* Errno for connection reset by peer. */
|
||||
#ifndef ECONNRESET
|
||||
# ifdef __hexagon__
|
||||
# define ECONNRESET 104
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* Abstraction of different OS specific sleep APIs.
|
||||
SLEEP accepts input in seconds. */
|
||||
#ifndef SLEEP
|
||||
# ifdef __hexagon__
|
||||
# define SLEEP(x) \
|
||||
{ /* Do nothing for simulator. */ \
|
||||
}
|
||||
# else
|
||||
# ifdef _WIN32
|
||||
# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */
|
||||
# else
|
||||
# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */
|
||||
# endif
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* Include windows specific header files. */
|
||||
#ifdef _WIN32
|
||||
# include <windows.h>
|
||||
# include <sysinfoapi.h>
|
||||
# define _CRT_SECURE_NO_WARNINGS 1
|
||||
# define _WINSOCK_DEPRECATED_NO_WARNINGS 1
|
||||
#endif
|
||||
|
||||
/* Includes and defines for all HLOS except windows */
|
||||
#if !defined(__hexagon__) && !defined(_WIN32)
|
||||
# include "unistd.h"
|
||||
|
||||
# include <sys/time.h>
|
||||
#endif
|
||||
|
||||
/* Includes and defines for Hexagon and all HLOS except Windows. */
|
||||
#if !defined(_WIN32)
|
||||
/* Weak reference to remote symbol for compilation. */
|
||||
# pragma weak remote_session_control
|
||||
# pragma weak remote_handle_control
|
||||
# pragma weak remote_handle64_control
|
||||
# pragma weak fastrpc_mmap
|
||||
# pragma weak fastrpc_munmap
|
||||
# pragma weak rpcmem_alloc2
|
||||
#endif
|
||||
|
||||
#if !defined(_WIN32)
|
||||
# pragma weak remote_system_request
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
# define DSPQUEUE_TIMEOUT DSPQUEUE_TIMEOUT_NONE
|
||||
#else
|
||||
# define DSPQUEUE_TIMEOUT 1000000
|
||||
#endif
|
||||
|
||||
/**
|
||||
* htpdrv_init API: driver interface entry point
|
||||
*
|
||||
* @return Return AEE error codes as defined in Hexagon SDK.
|
||||
*/
|
||||
HTPDRV_API int htpdrv_init(void);
|
||||
|
||||
/**
|
||||
* get_domain API: get domain struct from domain value.
|
||||
*
|
||||
* @param[in] domain value of a domain
|
||||
* @return Returns domain struct of the domain if it is supported or else
|
||||
* returns NULL.
|
||||
*
|
||||
*/
|
||||
HTPDRV_API domain * get_domain(int domain_id);
|
||||
|
||||
/**
|
||||
* get_hex_arch_ver API: query the Hexagon processor architecture version information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @param[out] Arch version (73, 75, ...)
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
HTPDRV_API int get_hex_arch_ver(int domain, int * arch);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1,454 +0,0 @@
|
||||
|
||||
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
|
||||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
#pragma clang diagnostic ignored "-Wsign-compare"
|
||||
|
||||
#define GGML_COMMON_IMPL_C
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-common.h"
|
||||
#include "ggml-hexagon.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#include "htp-utils.h"
|
||||
|
||||
#include <domain.h>
|
||||
#include <remote.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
domain * get_domain(int domain_id) {
|
||||
int i = 0;
|
||||
int size = sizeof(supported_domains) / sizeof(domain);
|
||||
|
||||
for (i = 0; i < size; i++) {
|
||||
if (supported_domains[i].id == domain_id) {
|
||||
return &supported_domains[i];
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
bool is_valid_domain_id(int domain_id, int compute_only) {
|
||||
int i = 0;
|
||||
int size = sizeof(supported_domains) / sizeof(domain);
|
||||
|
||||
if (compute_only) {
|
||||
return is_CDSP(domain_id);
|
||||
}
|
||||
|
||||
for (i = 0; i < size; i++) {
|
||||
if (supported_domains[i].id == domain_id) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
int ss_info = 0;
|
||||
if (domain_type != NULL) {
|
||||
if (strcmp(domain_type, "LPASS") == 0) {
|
||||
ss_info = FASTRPC_LPASS;
|
||||
} else if (strcmp(domain_type, "HPASS") == 0) {
|
||||
ss_info = FASTRPC_HPASS;
|
||||
} else {
|
||||
ss_info = FASTRPC_NSP;
|
||||
}
|
||||
}
|
||||
system_req_payload req = { 0 };
|
||||
req.id = FASTRPC_GET_DOMAINS;
|
||||
req.sys.domains = NULL;
|
||||
fastrpc_domain * domain = NULL;
|
||||
if (ss_info != 0) {
|
||||
req.sys.flags = DOMAINS_LIST_FLAGS_SET_TYPE(req.sys.flags, ss_info);
|
||||
} else {
|
||||
req.sys.flags = 0;
|
||||
}
|
||||
#ifdef _WIN32
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
goto bail;
|
||||
#endif
|
||||
if (remote_system_request) {
|
||||
nErr = remote_system_request(&req);
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
// Allocate memory for domain-info array
|
||||
req.sys.max_domains = req.sys.num_domains;
|
||||
if ((req.sys.domains = calloc(req.sys.num_domains, sizeof(fastrpc_domain))) == NULL) {
|
||||
nErr = AEE_ENOMEMORY;
|
||||
GGML_LOG_ERROR("Unable to allocate memory for req.sys.domains");
|
||||
goto bail;
|
||||
}
|
||||
|
||||
nErr = remote_system_request(&req);
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
|
||||
for (int i = 0; i < req.sys.num_domains; i++) {
|
||||
// Verify that only requested type domains were returned
|
||||
domain = &req.sys.domains[i];
|
||||
if (domain->type != ss_info && domain_type != NULL) {
|
||||
nErr = -1;
|
||||
GGML_LOG_ERROR("Incorrect data received from remote_system_request.\n");
|
||||
goto bail;
|
||||
}
|
||||
}
|
||||
*domains_info = req.sys.domains;
|
||||
*num_domains = req.sys.num_domains;
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
goto bail;
|
||||
}
|
||||
bail:
|
||||
if (nErr && !req.sys.domains) {
|
||||
free(req.sys.domains);
|
||||
}
|
||||
return nErr;
|
||||
}
|
||||
|
||||
int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id) {
|
||||
int err = 0;
|
||||
remote_rpc_effective_domain_id_t sess = { 0 };
|
||||
|
||||
sess.domain_name = domain_name;
|
||||
sess.domain_name_len = strlen(domain_name);
|
||||
sess.session_id = session_id;
|
||||
|
||||
err = remote_session_control(FASTRPC_GET_EFFECTIVE_DOMAIN_ID, &sess, sizeof(sess));
|
||||
if (err) {
|
||||
GGML_LOG_ERROR("Error 0x%x: failed to get effective domain id for %s, session id %d\n", err, sess.domain_name,
|
||||
session_id);
|
||||
return err;
|
||||
}
|
||||
|
||||
*effec_domain_id = sess.effective_domain_id;
|
||||
return err;
|
||||
}
|
||||
|
||||
int get_dsp_support(int * domain) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*domain = CDSP_DOMAIN_ID; // DSP domain default value is CDSP_DOMAIN_ID
|
||||
|
||||
if (remote_handle_control) {
|
||||
struct remote_dsp_capability dsp_capability_domain = { CDSP_DOMAIN_ID, DOMAIN_SUPPORT, 0 };
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
goto bail;
|
||||
}
|
||||
|
||||
if (dsp_capability_domain.capability == 0) {
|
||||
dsp_capability_domain.domain = ADSP_DOMAIN_ID; // Check for ADSP support.
|
||||
dsp_capability_domain.attribute_ID = DOMAIN_SUPPORT;
|
||||
dsp_capability_domain.capability = 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if (dsp_capability_domain.capability) {
|
||||
*domain = ADSP_DOMAIN_ID; // For targets like Agatti (not having cDSP), domain is ADSP_DOMAIN_ID
|
||||
}
|
||||
}
|
||||
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("\nget_dsp_support failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
|
||||
int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*capability = 0;
|
||||
|
||||
if (attr == VTCM_PAGE || attr == VTCM_COUNT) {
|
||||
} else {
|
||||
nErr = AEE_EBADPARM;
|
||||
GGML_LOG_ERROR("Unsupported attr. Only VTCM_PAGE and VTCM_COUNT supported\n");
|
||||
goto bail;
|
||||
}
|
||||
if (remote_handle_control) {
|
||||
if (domain == ADSP_DOMAIN_ID || domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for VTCM information
|
||||
* Since the ADSP does not have a dedicated VTCM, we expect the output to be 0
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_vtcm_dsp;
|
||||
dsp_capability_vtcm_dsp.domain = (uint32_t) domain;
|
||||
dsp_capability_vtcm_dsp.attribute_ID = attr;
|
||||
dsp_capability_vtcm_dsp.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_vtcm_dsp,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (nErr == AEE_SUCCESS) {
|
||||
*capability = dsp_capability_vtcm_dsp.capability;
|
||||
} else {
|
||||
GGML_LOG_ERROR("\nget_vtcm_info failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("Unsupported domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
|
||||
bool is_unsignedpd_supported(int domain_id) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
if (remote_handle_control) {
|
||||
struct remote_dsp_capability dsp_capability_domain = { domain_id, UNSIGNED_PD_SUPPORT, 0 };
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device. Falling back to signed pd.\n");
|
||||
return false;
|
||||
}
|
||||
if (nErr) {
|
||||
GGML_LOG_ERROR("\nERROR 0x%x: FastRPC Capability API failed. Falling back to signed pd.", nErr);
|
||||
return false;
|
||||
}
|
||||
if (dsp_capability_domain.capability == 1) {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device. Falling back to signed pd.\n");
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool get_unsignedpd_support(void) {
|
||||
return is_unsignedpd_supported(CDSP_DOMAIN_ID);
|
||||
}
|
||||
|
||||
bool is_async_fastrpc_supported(int domain) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
if (remote_handle_control) {
|
||||
if (domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for ASYNC_FASTRPC_SUPPORT information
|
||||
* Async fastrpc is supported only on CDSP
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_async_support;
|
||||
dsp_capability_async_support.domain = (uint32_t) domain;
|
||||
dsp_capability_async_support.attribute_ID = ASYNC_FASTRPC_SUPPORT;
|
||||
dsp_capability_async_support.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_async_support,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (dsp_capability_async_support.capability == 1) {
|
||||
return true;
|
||||
}
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("\nis_async_fastrpc_supported failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("Async fastrpc is not supported on domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_status_notification_supported(int domain) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
|
||||
if (remote_handle_control) {
|
||||
/*
|
||||
* Query the DSP for STATUS_NOTIFICATION_SUPPORT information
|
||||
* DSP User PD status notification Support
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_status_notification_support;
|
||||
dsp_capability_status_notification_support.domain = (uint32_t) domain;
|
||||
dsp_capability_status_notification_support.attribute_ID = STATUS_NOTIFICATION_SUPPORT;
|
||||
dsp_capability_status_notification_support.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_status_notification_support,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (dsp_capability_status_notification_support.capability == 1) {
|
||||
return true;
|
||||
}
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("\nis_status_notification_supported failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return false;
|
||||
}
|
||||
|
||||
int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*capability = 0;
|
||||
|
||||
if (attr != HMX_SUPPORT_SPATIAL && attr != HMX_SUPPORT_DEPTH) {
|
||||
nErr = AEE_EBADPARM;
|
||||
GGML_LOG_ERROR("Unsupported attr. Only HMX_SUPPORT_SPATIAL and HMX_SUPPORT_DEPTH supported\n");
|
||||
goto bail;
|
||||
}
|
||||
if (remote_handle_control) {
|
||||
if (domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for HMX SUPPORT information
|
||||
* HMX is supported on CDSP only
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_hmx_dsp;
|
||||
dsp_capability_hmx_dsp.domain = (uint32_t) domain;
|
||||
dsp_capability_hmx_dsp.attribute_ID = attr;
|
||||
dsp_capability_hmx_dsp.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hmx_dsp,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (nErr == AEE_SUCCESS) {
|
||||
*capability = dsp_capability_hmx_dsp.capability;
|
||||
} else {
|
||||
GGML_LOG_ERROR("\nget_hmx_support_info failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("HMX support is not there for domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
|
||||
int get_hex_arch_ver(int domain, int * arch) {
|
||||
if (!remote_handle_control) {
|
||||
GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
|
||||
return AEE_EUNSUPPORTEDAPI;
|
||||
}
|
||||
|
||||
struct remote_dsp_capability arch_ver;
|
||||
arch_ver.domain = (uint32_t) domain;
|
||||
arch_ver.attribute_ID = ARCH_VER;
|
||||
arch_ver.capability = (uint32_t) 0;
|
||||
|
||||
int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
|
||||
if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
|
||||
GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
|
||||
return AEE_EUNSUPPORTEDAPI;
|
||||
}
|
||||
|
||||
if (err != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
|
||||
switch (arch_ver.capability & 0xff) {
|
||||
case 0x68:
|
||||
*arch = 68;
|
||||
return 0;
|
||||
case 0x69:
|
||||
*arch = 69;
|
||||
return 0;
|
||||
case 0x73:
|
||||
*arch = 73;
|
||||
return 0;
|
||||
case 0x75:
|
||||
*arch = 75;
|
||||
return 0;
|
||||
case 0x79:
|
||||
*arch = 79;
|
||||
return 0;
|
||||
case 0x81:
|
||||
*arch = 81;
|
||||
return 0;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*capability = 0;
|
||||
|
||||
if (remote_handle_control) {
|
||||
if (domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for HVX SUPPORT information
|
||||
* HVX is supported on CDSP only
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_hvx_dsp;
|
||||
dsp_capability_hvx_dsp.domain = (uint32_t) domain;
|
||||
dsp_capability_hvx_dsp.attribute_ID = attr;
|
||||
dsp_capability_hvx_dsp.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hvx_dsp,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (nErr == AEE_SUCCESS) {
|
||||
*capability = dsp_capability_hvx_dsp.capability;
|
||||
} else {
|
||||
GGML_LOG_ERROR("\nget_hvx_support_info failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("HVX support is not available on domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
@@ -1,221 +0,0 @@
|
||||
#ifndef HTP_UTILS_H
|
||||
#define HTP_UTILS_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#include <AEEStdErr.h>
|
||||
#include <inttypes.h>
|
||||
#include <remote.h>
|
||||
#include <rpcmem.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
/* Offset to differentiate HLOS and Hexagon error codes.
|
||||
Stores the value of AEE_EOFFSET for Hexagon. */
|
||||
#ifndef DSP_OFFSET
|
||||
# define DSP_OFFSET 0x80000400
|
||||
#endif
|
||||
|
||||
/* Errno for connection reset by peer. */
|
||||
#ifndef ECONNRESET
|
||||
# ifdef __hexagon__
|
||||
# define ECONNRESET 104
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* Abstraction of different OS specific sleep APIs.
|
||||
SLEEP accepts input in seconds. */
|
||||
#ifndef SLEEP
|
||||
# ifdef __hexagon__
|
||||
# define SLEEP(x) \
|
||||
{ /* Do nothing for simulator. */ \
|
||||
}
|
||||
# else
|
||||
# ifdef _WINDOWS
|
||||
# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */
|
||||
# else
|
||||
# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */
|
||||
# endif
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* Include windows specific header files. */
|
||||
#ifdef _WINDOWS
|
||||
# include <sysinfoapi.h>
|
||||
# include <windows.h>
|
||||
# define _CRT_SECURE_NO_WARNINGS 1
|
||||
# define _WINSOCK_DEPRECATED_NO_WARNINGS 1
|
||||
/* Including this file for custom implementation of getopt function. */
|
||||
# include "getopt_custom.h"
|
||||
#endif
|
||||
|
||||
/* Includes and defines for all HLOS except windows */
|
||||
#if !defined(__hexagon__) && !defined(_WINDOWS)
|
||||
# include "unistd.h"
|
||||
|
||||
# include <sys/time.h>
|
||||
#endif
|
||||
|
||||
/* Includes and defines for Hexagon and all HLOS except Windows. */
|
||||
#if !defined(_WINDOWS)
|
||||
/* Weak reference to remote symbol for compilation. */
|
||||
# pragma weak remote_session_control
|
||||
# pragma weak remote_handle_control
|
||||
# pragma weak remote_handle64_control
|
||||
# pragma weak fastrpc_mmap
|
||||
# pragma weak fastrpc_munmap
|
||||
# pragma weak rpcmem_alloc2
|
||||
#endif
|
||||
|
||||
#if !defined(_WINDOWS)
|
||||
# pragma weak remote_system_request
|
||||
#endif
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query DSP support.
|
||||
*
|
||||
* @param[out] domain pointer to supported domain.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*/
|
||||
int get_dsp_support(int * domain);
|
||||
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query VTCM information.
|
||||
*
|
||||
* @param[in] domain value of domain in the queried.
|
||||
* @param[out] capability capability value of the attribute queried.
|
||||
* @param[in] attr value of the attribute to the queried.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*/
|
||||
int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr);
|
||||
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query unsigned pd support on CDSP domain.
|
||||
*
|
||||
* @return true if unsigned pd is supported.
|
||||
* false if unsigned pd is not supported, capability query failed.
|
||||
*/
|
||||
|
||||
bool get_unsignedpd_support(void);
|
||||
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query unsigned pd support.
|
||||
*
|
||||
* @param[in] domain value of domain in the queried.
|
||||
* @return true if unsigned pd is supported.
|
||||
* false if unsigned pd is not supported, capability query failed.
|
||||
*/
|
||||
|
||||
bool is_unsignedpd_supported(int domain_id);
|
||||
|
||||
/**
|
||||
* is_valid_domain_id API: query a domain id is valid.
|
||||
*
|
||||
* @param[in] domain value of domain in the queried.
|
||||
* @param[in] compute_only value of domain is only compared with CDSP domains supported by the target when enabled.
|
||||
* @return true if value of domain is valid.
|
||||
* false if value of domain is not valid.
|
||||
*/
|
||||
|
||||
bool is_valid_domain_id(int domain_id, int compute_only);
|
||||
|
||||
/**
|
||||
* get_domain API: get domain struct from domain value.
|
||||
*
|
||||
* @param[in] domain value of a domain
|
||||
* @return Returns domain struct of the domain if it is supported or else
|
||||
* returns NULL.
|
||||
*
|
||||
*/
|
||||
|
||||
domain * get_domain(int domain_id);
|
||||
|
||||
/**
|
||||
* get_domains_info API: get information for all the domains available on the device
|
||||
*
|
||||
* @param[in] domain_type pointer to domain type
|
||||
* @param[in] num_domains pointer to number of domains
|
||||
* @param[in] domains_info pointer to save discovered domains information.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
* It is user's responsibility to free the memory used to store the domains info whose address is present in domains_info before closing the application.
|
||||
*
|
||||
*/
|
||||
|
||||
int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info);
|
||||
|
||||
/**
|
||||
* get_effective_domain_id API: get effective domain id for given session id
|
||||
*
|
||||
* @param[in] domain_name pointer to domain name
|
||||
* @param[in] session_id
|
||||
* @param[in] effec_domain_id pointer to save obtained effective domain id.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
|
||||
int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id);
|
||||
|
||||
/**
|
||||
* is_async_fastrpc_supported API: query a domain id has async fastrpc supported or not
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @return Returns true or false stating support of Async FastRPC
|
||||
*
|
||||
*/
|
||||
|
||||
bool is_async_fastrpc_supported(int domain_id);
|
||||
|
||||
/**
|
||||
* is_status_notification_supported API: query the DSP for STATUS_NOTIFICATION_SUPPORT information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @return Returns true or false stating status notification support information
|
||||
*
|
||||
*/
|
||||
bool is_status_notification_supported(int domain_id);
|
||||
|
||||
/**
|
||||
* get_hmx_support_info API: query the DSP for HMX SUPPORT information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @param[out] capability capability value of the attribute queried.
|
||||
* @param[in] attr value of the attribute to the queried.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr);
|
||||
|
||||
/**
|
||||
* get_hex_arch_ver API: query the Hexagon processor architecture version information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @param[out] Arch version (73, 75, ...)
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
int get_hex_arch_ver(int domain, int * arch);
|
||||
|
||||
/**
|
||||
* get_hvx_support_info API: query the DSP for HVX SUPPORT information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @param[out] capability capability value of the attribute queried.
|
||||
* @param[in] attr value of the attribute to the queried.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif //DSP_CAPABILITIES_UTILS_H
|
||||
79
ggml/src/ggml-hexagon/libdl.h
Normal file
79
ggml/src/ggml-hexagon/libdl.h
Normal file
@@ -0,0 +1,79 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef _WIN32
|
||||
# define WIN32_LEAN_AND_MEAN
|
||||
# ifndef NOMINMAX
|
||||
# define NOMINMAX
|
||||
# endif
|
||||
# include <windows.h>
|
||||
# include <winevt.h>
|
||||
#else
|
||||
# include <dlfcn.h>
|
||||
# include <unistd.h>
|
||||
#endif
|
||||
#include <filesystem>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
using dl_handle = std::remove_pointer_t<HMODULE>;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(HMODULE handle) {
|
||||
FreeLibrary(handle);
|
||||
}
|
||||
};
|
||||
|
||||
static inline dl_handle * dl_load_library(const fs::path & path) {
|
||||
// suppress error dialogs for missing DLLs
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
HMODULE handle = LoadLibraryW(path.wstring().c_str());
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
void * p = (void *) GetProcAddress(handle, name);
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return p;
|
||||
}
|
||||
|
||||
static inline const char * dl_error() {
|
||||
return "";
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
using dl_handle = void;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(void * handle) {
|
||||
dlclose(handle);
|
||||
}
|
||||
};
|
||||
|
||||
static inline dl_handle * dl_load_library(const fs::path & path) {
|
||||
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
return handle;
|
||||
}
|
||||
|
||||
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
return dlsym(handle, name);
|
||||
}
|
||||
|
||||
static inline const char * dl_error() {
|
||||
const char *rslt = dlerror();
|
||||
return rslt != nullptr ? rslt : "";
|
||||
}
|
||||
|
||||
#endif
|
||||
38
ggml/src/ggml-hexagon/libggml-htp.inf
Normal file
38
ggml/src/ggml-hexagon/libggml-htp.inf
Normal file
@@ -0,0 +1,38 @@
|
||||
[Version]
|
||||
Signature = "$WINDOWS NT$"
|
||||
Class = ComputeAccelerator
|
||||
ClassGuid = {F01A9D53-3FF6-48D2-9F97-C8A7004BE10C}
|
||||
Provider = %GGML%
|
||||
DriverVer = 01/01/2026,1.0.0.0
|
||||
CatalogFile = libggml-htp.cat
|
||||
PnpLockDown = 1
|
||||
|
||||
[DestinationDirs]
|
||||
Drivers_Dir = 6
|
||||
|
||||
[SourceDisksNames]
|
||||
1 = %DiskId%
|
||||
|
||||
[SourceDisksFiles]
|
||||
libggml-htp-v68.so = 1
|
||||
libggml-htp-v69.so = 1
|
||||
libggml-htp-v73.so = 1
|
||||
libggml-htp-v75.so = 1
|
||||
libggml-htp-v81.so = 1
|
||||
|
||||
[ControlFlags]
|
||||
ExcludeFromSelect = *
|
||||
|
||||
[DefaultInstall.NTarm64]
|
||||
CopyFiles=Drivers_Dir
|
||||
|
||||
[Drivers_Dir]
|
||||
libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v81.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
|
||||
[Strings]
|
||||
GGML = 'GGML'
|
||||
DiskId = 'GGML HTP library'
|
||||
@@ -62,6 +62,8 @@ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/mmf*.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
|
||||
if (GGML_CUDA_FA_ALL_QUANTS)
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
|
||||
|
||||
@@ -101,6 +101,8 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mm_f32_f32_l4_lm
|
||||
mul_mm_f16_f32_l4_lm
|
||||
mul_mm_q8_0_f32_l4_lm
|
||||
mul_mm_q8_0_f32_8x4
|
||||
gemv_noshuffle_general_q8_0_f32
|
||||
mul
|
||||
norm
|
||||
relu
|
||||
|
||||
@@ -226,7 +226,8 @@ static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) {
|
||||
return ADRENO_GPU_GEN::A7X;
|
||||
}
|
||||
|
||||
if (strstr(device_name, "830")) {
|
||||
if (strstr(device_name, "830") ||
|
||||
strstr(device_name, "840")) {
|
||||
return ADRENO_GPU_GEN::A8X;
|
||||
}
|
||||
|
||||
@@ -529,7 +530,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
|
||||
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
|
||||
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
|
||||
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0;
|
||||
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
|
||||
cl_kernel kernel_convert_block_q4_0_noshuffle;
|
||||
cl_kernel kernel_restore_block_q4_0_noshuffle;
|
||||
@@ -696,6 +697,8 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096;
|
||||
cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096;
|
||||
cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096;
|
||||
cl_kernel kernel_mul_mm_q8_0_f32_8x4;
|
||||
cl_kernel CL_mul_mat_vec_q8_0_f32;
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
|
||||
void free() {
|
||||
@@ -894,6 +897,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
@@ -2290,6 +2294,46 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_q8_0_f32_8x4
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src_q8_8x4_gemm {
|
||||
#include "mul_mm_q8_0_f32_8x4.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src_q8_8x4_gemm = read_file("mul_mm_q8_0_f32_8x4.cl");
|
||||
#endif
|
||||
backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_q8_8x4_gemm.c_str(), compile_opts);
|
||||
CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mm_q8_0_f32_8x4", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemv_noshuffle_general_q8_0_f32
|
||||
{
|
||||
std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
|
||||
" -cl-mad-enable "
|
||||
" -DSIMDGROUP_WIDTH=" +
|
||||
std::to_string(backend_ctx->adreno_wave_size);
|
||||
if (backend_ctx->has_vector_subgroup_broadcast) {
|
||||
CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
|
||||
}
|
||||
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src_CL_gemv_general {
|
||||
#include "gemv_noshuffle_general_q8_0_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_general_q8_0_f32.cl");
|
||||
#endif
|
||||
|
||||
cl_program prog = build_program_from_source(
|
||||
backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->CL_mul_mat_vec_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std +
|
||||
" -cl-mad-enable "
|
||||
" -cl-fast-relaxed-math";
|
||||
@@ -3745,6 +3789,15 @@ inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ct
|
||||
return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
|
||||
}
|
||||
|
||||
inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
|
||||
|
||||
bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor);
|
||||
|
||||
size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3];
|
||||
|
||||
return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27
|
||||
}
|
||||
|
||||
static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
|
||||
|
||||
@@ -4159,6 +4212,130 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
|
||||
tensor->extra = extra;
|
||||
|
||||
// Transpose the weights and scales
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
if (enable_adreno_trans_weight(backend_ctx, tensor)) {
|
||||
|
||||
int M = tensor->ne[1]; // ne01
|
||||
int K = tensor->ne[0]; // ne00
|
||||
|
||||
GGML_ASSERT(K % 32 == 0);
|
||||
GGML_ASSERT(M % 4 == 0);
|
||||
GGML_ASSERT(tensor->ne[2] == 1);
|
||||
GGML_ASSERT(tensor->ne[3] == 1);
|
||||
|
||||
// Transpose weights
|
||||
size_t q_size_bytes = K * M / 4 * sizeof(float);
|
||||
cl_buffer_region region;
|
||||
region.origin = 0;
|
||||
region.size = q_size_bytes;
|
||||
cl_mem qT_d = clCreateSubBuffer(
|
||||
backend_ctx->prealloc_quant_trans.buffer,
|
||||
0,
|
||||
CL_BUFFER_CREATE_TYPE_REGION,
|
||||
®ion,
|
||||
&err);
|
||||
CL_CHECK(err);
|
||||
|
||||
cl_mem q_d_image1D;
|
||||
cl_mem qT_d_image1D;
|
||||
|
||||
cl_image_format img_fmt_1d;
|
||||
cl_image_desc img_desc_1d;
|
||||
|
||||
img_fmt_1d = { CL_RGBA, CL_FLOAT };
|
||||
memset(&img_desc_1d, 0, sizeof(img_desc_1d));
|
||||
img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc_1d.image_width = M * K / 4 / 4;
|
||||
img_desc_1d.buffer = extra->q;
|
||||
q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
img_fmt_1d = { CL_RGBA, CL_FLOAT };
|
||||
memset(&img_desc_1d, 0, sizeof(img_desc_1d));
|
||||
img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc_1d.image_width = M * K / 4 / 4;
|
||||
img_desc_1d.buffer = qT_d;
|
||||
qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
int height_q = M / 4;
|
||||
int width_q = K / 4 / 4;
|
||||
kernel = backend_ctx->kernel_transpose_32;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_q));
|
||||
|
||||
size_t local_size_q[3] = {4, 16, 1};
|
||||
size_t global_size_q[3] = {static_cast<size_t>(width_q), static_cast<size_t>(height_q), 1};
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
|
||||
// Transpose scales
|
||||
size_t d_size_bytes = M * (K / 32) * 2;
|
||||
region.origin = 0;
|
||||
region.size = d_size_bytes;
|
||||
cl_mem dT_d = clCreateSubBuffer(
|
||||
backend_ctx->prealloc_scales_trans.buffer,
|
||||
0,
|
||||
CL_BUFFER_CREATE_TYPE_REGION,
|
||||
®ion,
|
||||
&err);
|
||||
CL_CHECK(err);
|
||||
|
||||
cl_mem d_d_image1D;
|
||||
cl_mem dT_d_image1D;
|
||||
|
||||
memset(&img_desc_1d, 0, sizeof(img_desc_1d));
|
||||
img_fmt_1d = { CL_R, CL_HALF_FLOAT };
|
||||
img_desc_1d.image_width = M * K / 32;
|
||||
img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc_1d.buffer = extra->d;
|
||||
d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT };
|
||||
memset(&img_desc_1d, 0, sizeof(img_desc_1d));
|
||||
img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc_1d.image_width = M * K / 32 / 4;
|
||||
img_desc_1d.buffer = dT_d;
|
||||
dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
int height_s = M / 4;
|
||||
int width_s = K / 32;
|
||||
|
||||
kernel = backend_ctx->kernel_transpose_16_4x1;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s));
|
||||
|
||||
size_t local_size_s[3] = {4, 16, 1};
|
||||
size_t global_size_s[3] = {static_cast<size_t>(width_s), static_cast<size_t>(height_s), 1};
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
|
||||
// copy transposed buffer contents to original buffers
|
||||
CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
|
||||
CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
|
||||
CL_CHECK(clReleaseMemObject(qT_d));
|
||||
CL_CHECK(clReleaseMemObject(dT_d));
|
||||
|
||||
CL_CHECK(clReleaseMemObject(q_d_image1D));
|
||||
CL_CHECK(clReleaseMemObject(d_d_image1D));
|
||||
CL_CHECK(clReleaseMemObject(qT_d_image1D));
|
||||
CL_CHECK(clReleaseMemObject(dT_d_image1D));
|
||||
} // end transpose
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q6_K) {
|
||||
@@ -4448,6 +4625,36 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
ggml_nbytes(tensor), NULL, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
if (enable_adreno_trans_weight(backend_ctx, tensor)) {
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0_trans;
|
||||
|
||||
int ne00 = tensor->ne[0];
|
||||
int ne01 = tensor->ne[1];
|
||||
GGML_ASSERT(tensor->ne[2] == 1); // ???
|
||||
GGML_ASSERT(tensor->ne[3] == 1); // ???
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01));
|
||||
|
||||
size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), 1, 1};
|
||||
size_t local_work_size[3] = {64, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
|
||||
global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
|
||||
CL_CHECK(clEnqueueReadBuffer(
|
||||
queue, data_device, CL_TRUE, offset,
|
||||
size, data, 0, NULL, NULL));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
|
||||
@@ -7947,6 +8154,252 @@ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_ten
|
||||
CL_CHECK(clReleaseMemObject(D_sub_buffer));
|
||||
}
|
||||
|
||||
static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
GGML_ASSERT(src1);
|
||||
GGML_ASSERT(src1->extra);
|
||||
GGML_ASSERT(dst);
|
||||
GGML_ASSERT(dst->extra);
|
||||
|
||||
const enum ggml_type src0t = src0->type;
|
||||
const enum ggml_type src1t = src1->type;
|
||||
|
||||
GGML_ASSERT(src0t == GGML_TYPE_Q8_0);
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
|
||||
|
||||
GGML_ASSERT(src1->view_offs == 0);
|
||||
GGML_ASSERT(dst->view_offs == 0);
|
||||
|
||||
const int ne00 = src0->ne[0];
|
||||
const int ne01 = src0->ne[1];
|
||||
const int ne02 = src0->ne[2];
|
||||
|
||||
const int ne10 = src1->ne[0];
|
||||
const int ne12 = src1->ne[2];
|
||||
|
||||
const int ne0 = dst->ne[0];
|
||||
const int ne1 = dst->ne[1];
|
||||
|
||||
GGML_ASSERT(ne00 == ne10);
|
||||
GGML_ASSERT((ne00 % 32) == 0);
|
||||
GGML_ASSERT(ne0 == ne01);
|
||||
|
||||
cl_context context = backend_ctx->context;
|
||||
cl_kernel kernel;
|
||||
|
||||
// init CL objects
|
||||
cl_int status;
|
||||
cl_image_format img_fmt_1d;
|
||||
cl_image_desc img_desc_1d;
|
||||
cl_buffer_region region;
|
||||
cl_mem A_image1d;
|
||||
cl_mem B_image1d;
|
||||
cl_mem B_sub_buffer;
|
||||
cl_mem S_image1d;
|
||||
|
||||
cl_mem D_image1d;
|
||||
cl_mem D_sub_buffer;
|
||||
|
||||
int M = ne01;
|
||||
int N = ne1;
|
||||
int K = ne00;
|
||||
|
||||
// create an image for A
|
||||
img_fmt_1d = { CL_R, CL_FLOAT};
|
||||
memset(&img_desc_1d, 0, sizeof(img_desc_1d));
|
||||
img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc_1d.image_width = M * K / 4; // Divide by 4 for char -> float
|
||||
img_desc_1d.buffer = extra0_q8_0->q;
|
||||
A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// create an image for Scale
|
||||
img_fmt_1d = { CL_R, CL_HALF_FLOAT};
|
||||
memset(&img_desc_1d, 0, sizeof(img_desc_1d));
|
||||
img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc_1d.image_width = M * K / 32; // Block size is 32
|
||||
img_desc_1d.buffer = extra0_q8_0->d;
|
||||
S_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// create a sub_buffer for B
|
||||
region.origin = (extra1->offset); // + src1->view_offs);
|
||||
region.size = K * N * sizeof(float);
|
||||
B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// create an image for B from sub_buffer: RGBA (OCL)
|
||||
img_fmt_1d = {CL_RGBA, CL_FLOAT};
|
||||
memset(&img_desc_1d, 0, sizeof(img_desc_1d));
|
||||
img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc_1d.image_width = K * N / 4;
|
||||
img_desc_1d.buffer = B_sub_buffer;
|
||||
B_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// Create subbuffer and image1d_buffer for dst
|
||||
region.origin = (extrad->offset); // + dst->view_offs;
|
||||
region.size = M * N * sizeof(float);
|
||||
D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
img_fmt_1d = {CL_R, CL_FLOAT};
|
||||
memset(&img_desc_1d, 0, sizeof(img_desc_1d));
|
||||
img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc_1d.image_width = M * N;
|
||||
img_desc_1d.buffer = D_sub_buffer;
|
||||
D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
size_t local_work_size[3] = {1, 1, 1};
|
||||
size_t global_work_size[3] = {1, 1, 1};
|
||||
|
||||
if (N == 1) {
|
||||
kernel = backend_ctx->CL_mul_mat_vec_q8_0_f32;
|
||||
|
||||
int r2 = 1;
|
||||
int r3 = 1;
|
||||
cl_uint k_arg = 0;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extra0_q8_0->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_image1d));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extra1->offset));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extrad->offset));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne10));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r3));
|
||||
|
||||
size_t wavesize = backend_ctx->adreno_wave_size;
|
||||
local_work_size[0] = wavesize;
|
||||
local_work_size[1] = 4; // reduce factor
|
||||
local_work_size[2] = 1;
|
||||
|
||||
global_work_size[0] = ((M + wavesize - 1) / wavesize) * wavesize;
|
||||
global_work_size[1] = 4; // reduce factor
|
||||
global_work_size[2] = 1;
|
||||
} else {
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
cl_mem B_image1d_trans = nullptr;
|
||||
// for B transpose
|
||||
cl_mem B_d = nullptr;
|
||||
int padding;
|
||||
|
||||
//how many extra elements beyond multiple of 8
|
||||
int extra_elements = N % 8;
|
||||
|
||||
//how much padding to add
|
||||
padding = 0;
|
||||
if (extra_elements > 0){
|
||||
padding = 8 - extra_elements;
|
||||
}
|
||||
|
||||
// Specify the starting offset (in bytes)
|
||||
region.origin = 0;
|
||||
// Specify the size of the sub-buffer (divide by 2 for FP16)
|
||||
region.size = K * (N + padding) * sizeof(float)/2;
|
||||
backend_ctx->prealloc_act_trans.allocate(context, region.size);
|
||||
B_d = clCreateSubBuffer(
|
||||
backend_ctx->prealloc_act_trans.buffer,
|
||||
0,
|
||||
CL_BUFFER_CREATE_TYPE_REGION,
|
||||
®ion,
|
||||
&status);
|
||||
CL_CHECK(status);
|
||||
|
||||
cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16)
|
||||
cl_image_desc image_desc_B_d_output = {
|
||||
CL_MEM_OBJECT_IMAGE1D_BUFFER,
|
||||
static_cast<size_t>(K * (N + padding)/4),
|
||||
0, 0, 0, 0, 0, 0, 0, { B_d }
|
||||
};
|
||||
B_image1d_trans = clCreateImage(
|
||||
context,
|
||||
0,
|
||||
&image_format_B_d_output,
|
||||
&image_desc_B_d_output,
|
||||
NULL,
|
||||
&status);
|
||||
CL_CHECK(status);
|
||||
|
||||
int height_B = N/4;
|
||||
if (height_B == 0) {
|
||||
height_B = 1;
|
||||
}
|
||||
int width_B = K/4;
|
||||
int padded_height_B = (N + padding)/4;
|
||||
|
||||
kernel = backend_ctx->kernel_transpose_32_16;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_image1d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d_trans));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B));
|
||||
|
||||
size_t local_size_t[2] = { 1, 16 };
|
||||
size_t global_size_t[2] = {
|
||||
static_cast<size_t>(width_B),
|
||||
static_cast<size_t>(padded_height_B)
|
||||
};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst);
|
||||
|
||||
kernel = backend_ctx->kernel_mul_mm_q8_0_f32_8x4;
|
||||
|
||||
int N_with_padding = N + padding;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d_trans));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &K));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &M));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &N_with_padding));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &N));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd));
|
||||
|
||||
global_work_size[0] = (size_t)(N + 7) / 8;
|
||||
global_work_size[1] = (size_t)(M + 3) / 4;
|
||||
global_work_size[2] = 1;
|
||||
|
||||
local_work_size[0] = 2;
|
||||
local_work_size[1] = 128;
|
||||
local_work_size[2] = 1;
|
||||
}
|
||||
|
||||
// enqueue kernel with profiling
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
|
||||
// deallocate sub buffers and images
|
||||
CL_CHECK(clReleaseMemObject(A_image1d));
|
||||
CL_CHECK(clReleaseMemObject(B_sub_buffer));
|
||||
CL_CHECK(clReleaseMemObject(B_image1d));
|
||||
CL_CHECK(clReleaseMemObject(S_image1d));
|
||||
CL_CHECK(clReleaseMemObject(D_sub_buffer));
|
||||
CL_CHECK(clReleaseMemObject(D_image1d));
|
||||
#else
|
||||
GGML_UNUSED(src0);
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
@@ -8064,6 +8517,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
int padding;
|
||||
// <--------------------------------------------> //
|
||||
|
||||
// q8_0 x fp32
|
||||
if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 &&
|
||||
enable_adreno_trans_weight(backend_ctx, src0)) {
|
||||
ggml_cl_mul_mat_q8_0_f32_adreno(backend, src0, src1, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
// q4_0 x fp32
|
||||
if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) {
|
||||
// TODO: remove duplicate definitions of image description + format -- move to top
|
||||
|
||||
@@ -274,6 +274,37 @@ kernel void kernel_restore_block_q8_0(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_q8_0_trans(
|
||||
global uchar * src_q,
|
||||
global half * src_d,
|
||||
global block_q8_0 * dst,
|
||||
uint ne00,
|
||||
uint ne01
|
||||
){
|
||||
uint num_blk_per_row = ne00 / QK8_0;
|
||||
|
||||
global block_q8_0 * b = (global block_q8_0 *) dst + get_global_id(0) * num_blk_per_row;
|
||||
global uchar * q = (global uchar *) src_q + get_global_id(0) * 4; // 4 8-bit packed
|
||||
global half * d = (global half *) src_d + get_global_id(0);
|
||||
|
||||
for (uint blk = 0; blk < num_blk_per_row; blk++) {
|
||||
b->d = *d;
|
||||
|
||||
for (uint i = 0; i < QK8_0; i+=4) {
|
||||
b->qs[i] = q[0];
|
||||
b->qs[i+1] = q[1];
|
||||
b->qs[i+2] = q[2];
|
||||
b->qs[i+3] = q[3];
|
||||
|
||||
q += 4 * ne01; // M stride
|
||||
}
|
||||
|
||||
d += ne01;
|
||||
|
||||
b++;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_convert_block_q6_K
|
||||
// Convert the block_q6_K format to 3 separate arrays (AOS -> SOA).
|
||||
|
||||
195
ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl
Normal file
195
ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl
Normal file
@@ -0,0 +1,195 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
|
||||
#ifdef cl_qcom_reqd_sub_group_size
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#endif
|
||||
|
||||
#define QK8_0 32
|
||||
#define N_SIMDGROUP 4
|
||||
|
||||
#define dequantizeBlockAccum_ns_sgbroadcast_1(total_sums, bits8, scale, y) \
|
||||
float shared_y; \
|
||||
char elem; \
|
||||
\
|
||||
shared_y = sub_group_broadcast(y.s0, 0); \
|
||||
elem = (char)(bits8.s0 & 0x000000FF); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s1, 0); \
|
||||
elem = (char)((bits8.s0 & 0x0000FF00) >> 8); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s2, 0); \
|
||||
elem = (char)((bits8.s0 & 0x00FF0000) >> 16); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s3, 0); \
|
||||
elem = (char)((bits8.s0 & 0xFF000000) >> 24); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
\
|
||||
shared_y = sub_group_broadcast(y.s4, 0); \
|
||||
elem = (char)(bits8.s1 & 0x000000FF); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s5, 0); \
|
||||
elem = (char)((bits8.s1 & 0x0000FF00) >> 8); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s6, 0); \
|
||||
elem = (char)((bits8.s1 & 0x00FF0000) >> 16); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s7, 0); \
|
||||
elem = (char)((bits8.s1 & 0xFF000000) >> 24); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
\
|
||||
shared_y = sub_group_broadcast(y.s0, 1); \
|
||||
elem = (char)(bits8.s2 & 0x000000FF); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s1, 1); \
|
||||
elem = (char)((bits8.s2 & 0x0000FF00) >> 8); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s2, 1); \
|
||||
elem = (char)((bits8.s2 & 0x00FF0000) >> 16); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s3, 1); \
|
||||
elem = (char)((bits8.s2 & 0xFF000000) >> 24); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
\
|
||||
shared_y = sub_group_broadcast(y.s4, 1); \
|
||||
elem = (char)(bits8.s3 & 0x000000FF); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s5, 1); \
|
||||
elem = (char)((bits8.s3 & 0x0000FF00) >> 8); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s6, 1); \
|
||||
elem = (char)((bits8.s3 & 0x00FF0000) >> 16); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s7, 1); \
|
||||
elem = (char)((bits8.s3 & 0xFF000000) >> 24); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
\
|
||||
shared_y = sub_group_broadcast(y.s0, 2); \
|
||||
elem = (char)(bits8.s4 & 0x000000FF); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s1, 2); \
|
||||
elem = (char)((bits8.s4 & 0x0000FF00) >> 8); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s2, 2); \
|
||||
elem = (char)((bits8.s4 & 0x00FF0000) >> 16); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s3, 2); \
|
||||
elem = (char)((bits8.s4 & 0xFF000000) >> 24); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
\
|
||||
shared_y = sub_group_broadcast(y.s4, 2); \
|
||||
elem = (char)(bits8.s5 & 0x000000FF); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s5, 2); \
|
||||
elem = (char)((bits8.s5 & 0x0000FF00) >> 8); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s6, 2); \
|
||||
elem = (char)((bits8.s5 & 0x00FF0000) >> 16); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s7, 2); \
|
||||
elem = (char)((bits8.s5 & 0xFF000000) >> 24); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
\
|
||||
shared_y = sub_group_broadcast(y.s0, 3); \
|
||||
elem = (char)(bits8.s6 & 0x000000FF); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s1, 3); \
|
||||
elem = (char)((bits8.s6 & 0x0000FF00) >> 8); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s2, 3); \
|
||||
elem = (char)((bits8.s6 & 0x00FF0000) >> 16); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s3, 3); \
|
||||
elem = (char)((bits8.s6 & 0xFF000000) >> 24); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
\
|
||||
shared_y = sub_group_broadcast(y.s4, 3); \
|
||||
elem = (char)(bits8.s7 & 0x000000FF); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s5, 3); \
|
||||
elem = (char)((bits8.s7 & 0x0000FF00) >> 8); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s6, 3); \
|
||||
elem = (char)((bits8.s7 & 0x00FF0000) >> 16); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s7, 3); \
|
||||
elem = (char)((bits8.s7 & 0xFF000000) >> 24); \
|
||||
total_sums += convert_int(elem) * scale * shared_y; \
|
||||
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
__kernel void kernel_gemv_noshuffle(
|
||||
__read_only image1d_buffer_t src0_q, // quantized A
|
||||
global half * src0_d, // A scales
|
||||
__read_only image1d_buffer_t src1, // B
|
||||
ulong offset1, // offset to B (0)
|
||||
global float * dst, // C
|
||||
ulong offsetd, // offset to C
|
||||
int ne00, // K
|
||||
int ne01, // M
|
||||
int ne02, // 1
|
||||
int ne10, // K
|
||||
int ne12, // 1
|
||||
int ne0, // M
|
||||
int ne1, // N
|
||||
int r2, // 1
|
||||
int r3)
|
||||
{
|
||||
uint groupId = get_local_id(1);
|
||||
uint gid = get_global_id(0);
|
||||
ushort slid = get_sub_group_local_id();
|
||||
|
||||
uint K = ne00;
|
||||
uint M = ne01;
|
||||
|
||||
uint LINE_STRIDE_A = M;
|
||||
uint BLOCK_STRIDE_A = 8 * M; // 32 / 4 = 8
|
||||
|
||||
__private uint8 regA;
|
||||
__private half regS;
|
||||
__private float8 regB;
|
||||
|
||||
__private float totalSum = (float)(0.0f);
|
||||
|
||||
// loop along K in block granularity, skip 4 blocks every iter
|
||||
#pragma unroll 1 /* tell compiler not to unroll */
|
||||
for (uint k = groupId; k < (K / QK8_0); k += N_SIMDGROUP) {
|
||||
regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of one rows
|
||||
// first 4 fibers in each wave load 8 B values to its private scope
|
||||
if (slid < 4) {
|
||||
regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
|
||||
regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
|
||||
}
|
||||
|
||||
// load weights for one block in consecutive rows
|
||||
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
|
||||
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
|
||||
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
|
||||
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
|
||||
regA.s4 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
|
||||
regA.s5 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
|
||||
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
|
||||
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
|
||||
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #wave=4
|
||||
__local float reduceLM[SIMDGROUP_WIDTH * 3];
|
||||
if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
|
||||
if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
|
||||
if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
|
||||
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
|
||||
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
|
||||
|
||||
// 1 outputs per fiber in wave 0
|
||||
if (groupId == 0) {
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
dst[gid] = totalSum;
|
||||
}
|
||||
}
|
||||
129
ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl
Normal file
129
ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl
Normal file
@@ -0,0 +1,129 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
|
||||
#ifdef cl_qcom_reqd_sub_group_size
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_128
|
||||
#endif
|
||||
|
||||
kernel void kernel_mul_mm_q8_0_f32_8x4(
|
||||
global const uint * src0_q,
|
||||
global const half * src0_d,
|
||||
__read_only image1d_buffer_t src1,
|
||||
global float * dst,
|
||||
int k,
|
||||
int m,
|
||||
int n,
|
||||
int n_no_padding,
|
||||
ulong offsetd
|
||||
) {
|
||||
|
||||
int m_4 = m >> 2;
|
||||
int n_4 = n >> 2;
|
||||
|
||||
int gy = get_global_id(0);
|
||||
int gx = get_global_id(1);
|
||||
int gx_2 = gx << 2;
|
||||
dst = (global float *)((global char*)dst + offsetd);
|
||||
|
||||
|
||||
half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
|
||||
half8 B;
|
||||
half4 deq;
|
||||
|
||||
__global const uint* wptr = src0_q + gx_2;
|
||||
__global const half* sptr = src0_d + gx_2;
|
||||
|
||||
for (int i = 0; i < k; i += 4) {
|
||||
uint4 pack4 = vload4(0, wptr + (i / 4) * m);
|
||||
half4 scale = vload4(0, sptr + (i / 32) * m);
|
||||
|
||||
char4 p0 = as_char4(pack4.s0);
|
||||
char4 p1 = as_char4(pack4.s1);
|
||||
char4 p2 = as_char4(pack4.s2);
|
||||
char4 p3 = as_char4(pack4.s3);
|
||||
|
||||
// ------------------- j = 0 (k = i+0) -------------------
|
||||
B.s0123 = read_imageh(src1, gy * 2 + (i + 0) * n_4);
|
||||
B.s4567 = read_imageh(src1, gy * 2 + (i + 0) * n_4 + 1);
|
||||
|
||||
half4 wj0 = convert_half4((char4)(p0.s0, p1.s0, p2.s0, p3.s0)) * scale;
|
||||
|
||||
c0 += B * wj0.s0;
|
||||
c1 += B * wj0.s1;
|
||||
c2 += B * wj0.s2;
|
||||
c3 += B * wj0.s3;
|
||||
|
||||
// ------------------- j = 1 (k = i+1) -------------------
|
||||
B.s0123 = read_imageh(src1, gy * 2 + (i + 1) * n_4);
|
||||
B.s4567 = read_imageh(src1, gy * 2 + (i + 1) * n_4 + 1);
|
||||
|
||||
half4 wj1 = convert_half4((char4)(p0.s1, p1.s1, p2.s1, p3.s1)) * scale;
|
||||
|
||||
c0 += B * wj1.s0;
|
||||
c1 += B * wj1.s1;
|
||||
c2 += B * wj1.s2;
|
||||
c3 += B * wj1.s3;
|
||||
|
||||
// ------------------- j = 2 (k = i+2) -------------------
|
||||
B.s0123 = read_imageh(src1, gy * 2 + (i + 2) * n_4);
|
||||
B.s4567 = read_imageh(src1, gy * 2 + (i + 2) * n_4 + 1);
|
||||
|
||||
half4 wj2 = convert_half4((char4)(p0.s2, p1.s2, p2.s2, p3.s2)) * scale;
|
||||
|
||||
c0 += B * wj2.s0;
|
||||
c1 += B * wj2.s1;
|
||||
c2 += B * wj2.s2;
|
||||
c3 += B * wj2.s3;
|
||||
|
||||
// ------------------- j = 3 (k = i+3) -------------------
|
||||
B.s0123 = read_imageh(src1, gy * 2 + (i + 3) * n_4);
|
||||
B.s4567 = read_imageh(src1, gy * 2 + (i + 3) * n_4 + 1);
|
||||
|
||||
half4 wj3 = convert_half4((char4)(p0.s3, p1.s3, p2.s3, p3.s3)) * scale;
|
||||
|
||||
c0 += B * wj3.s0;
|
||||
c1 += B * wj3.s1;
|
||||
c2 += B * wj3.s2;
|
||||
c3 += B * wj3.s3;
|
||||
}
|
||||
|
||||
int idx = (gy << 3) * m + (gx << 2);
|
||||
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,6 @@
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
#include <sycl/half_type.hpp>
|
||||
#include <syclcompat/math.hpp>
|
||||
#include <map>
|
||||
|
||||
#ifdef GGML_SYCL_USE_INTEL_ONEMKL
|
||||
|
||||
@@ -123,6 +123,15 @@ static __dpct_inline__ T op_log(T x) {
|
||||
return sycl::log(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_softplus(T x) {
|
||||
const float xf = (float) x;
|
||||
const float ax = sycl::fabs(xf);
|
||||
const float m = sycl::fmax(xf, 0.0f);
|
||||
const float y = m + sycl::log1p(sycl::exp(-ax));
|
||||
return (T) y;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_neg(T x) {
|
||||
return -x;
|
||||
@@ -695,6 +704,12 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_softplus(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_neg(x);
|
||||
@@ -1101,6 +1116,11 @@ void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_log(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_softplus(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_neg(ctx, dst);
|
||||
|
||||
@@ -61,6 +61,8 @@ void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -2263,6 +2263,65 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten
|
||||
diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
|
||||
}
|
||||
|
||||
static void tri_f32_sycl(
|
||||
const float * src,
|
||||
float * dst,
|
||||
const int64_t ne0,
|
||||
const int64_t ne1,
|
||||
const int64_t ne2,
|
||||
const int64_t ne3,
|
||||
const ggml_tri_type ttype,
|
||||
dpct::queue_ptr main_stream
|
||||
) {
|
||||
const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3;
|
||||
|
||||
main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) {
|
||||
const int64_t idx = (int64_t) tid[0];
|
||||
|
||||
const int64_t i0 = idx % ne0;
|
||||
const int64_t t1 = idx / ne0;
|
||||
const int64_t i1 = t1 % ne1;
|
||||
|
||||
bool keep = false;
|
||||
switch (ttype) {
|
||||
case GGML_TRI_TYPE_LOWER: keep = (i0 < i1); break;
|
||||
case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break;
|
||||
case GGML_TRI_TYPE_UPPER: keep = (i0 > i1); break;
|
||||
case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break;
|
||||
default: keep = false; break;
|
||||
}
|
||||
|
||||
dst[idx] = keep ? src[idx] : 0.0f;
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
GGML_ASSERT(src0);
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
const float * src0_dd = static_cast<const float *>(src0->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
const int64_t ne0 = src0->ne[0];
|
||||
const int64_t ne1 = src0->ne[1];
|
||||
const int64_t ne2 = src0->ne[2];
|
||||
const int64_t ne3 = src0->ne[3];
|
||||
|
||||
tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream);
|
||||
}
|
||||
|
||||
|
||||
inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
@@ -3786,6 +3845,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_EXP:
|
||||
ggml_sycl_exp(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
ggml_sycl_softplus(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_SGN:
|
||||
ggml_sycl_sgn(ctx, dst);
|
||||
break;
|
||||
@@ -3912,6 +3974,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_TRANSPOSE:
|
||||
GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
|
||||
break;
|
||||
case GGML_OP_TRI:
|
||||
ggml_sycl_op_tri(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
ggml_sycl_diag_mask_inf(ctx, dst);
|
||||
break;
|
||||
@@ -4404,6 +4469,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
return true;
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
@@ -4606,18 +4672,23 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
|
||||
#endif
|
||||
case GGML_OP_NORM:
|
||||
return true;
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_RMS_NORM:
|
||||
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
|
||||
return true;
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_SCALE:
|
||||
return true;
|
||||
case GGML_OP_CONT:
|
||||
return op->src[0]->type != GGML_TYPE_BF16;
|
||||
case GGML_OP_TRI:
|
||||
{
|
||||
const ggml_tensor * src0 = op->src[0];
|
||||
return src0 &&
|
||||
op->type == GGML_TYPE_F32 &&
|
||||
ggml_is_contiguous(src0);
|
||||
}
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
return true;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
|
||||
@@ -251,7 +251,6 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
|
||||
const float eps, queue_ptr stream, int device) {
|
||||
|
||||
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
if (ncols < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
@@ -334,7 +333,6 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
||||
|
||||
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
||||
|
||||
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
|
||||
@@ -374,7 +372,6 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
|
||||
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
const int nrows, const float eps,
|
||||
queue_ptr stream, int device) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
||||
if (ncols < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
|
||||
70
ggml/src/ggml-virtgpu/CMakeLists.txt
Normal file
70
ggml/src/ggml-virtgpu/CMakeLists.txt
Normal file
@@ -0,0 +1,70 @@
|
||||
cmake_minimum_required(VERSION 3.19)
|
||||
cmake_policy(SET CMP0114 NEW)
|
||||
|
||||
include(ExternalProject)
|
||||
|
||||
message(STATUS "Including the VirtGPU/Virglrenderer API Remoting")
|
||||
|
||||
# Download venus_hw.h from virglrenderer repository
|
||||
ExternalProject_Add(
|
||||
venus_hw_header
|
||||
URL https://gitlab.freedesktop.org/virgl/virglrenderer/-/raw/virglrenderer-1.2.0/src/venus_hw.h
|
||||
DOWNLOAD_NO_EXTRACT YES
|
||||
DOWNLOAD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include
|
||||
DOWNLOAD_NAME venus_hw.h
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
INSTALL_COMMAND ""
|
||||
LOG_DOWNLOAD ON
|
||||
)
|
||||
|
||||
if (NOT GGML_VIRTGPU_BACKEND STREQUAL "ONLY")
|
||||
message(STATUS "Enable the VirtGPU/Virglrenderer API Remoting frontend library")
|
||||
|
||||
find_package(PkgConfig REQUIRED)
|
||||
pkg_check_modules(DRM REQUIRED libdrm)
|
||||
if (NOT GGML_BACKEND_DL)
|
||||
# cannot simply use USE_VIRTGPU, as in the 'else()' case the
|
||||
# frontend isn't compiled
|
||||
target_compile_definitions(ggml PUBLIC "GGML_USE_VIRTGPU_FRONTEND")
|
||||
endif()
|
||||
|
||||
ggml_add_backend_library(ggml-virtgpu
|
||||
ggml-backend-buffer.cpp
|
||||
ggml-backend.cpp
|
||||
ggml-backend-device.cpp
|
||||
ggml-backend-reg.cpp
|
||||
ggml-backend-buffer-type.cpp
|
||||
virtgpu-apir.h
|
||||
virtgpu-forward.gen.h
|
||||
virtgpu.cpp
|
||||
virtgpu-shm.cpp
|
||||
virtgpu-utils.cpp
|
||||
virtgpu-forward-device.cpp
|
||||
virtgpu-forward-buffer-type.cpp
|
||||
virtgpu-forward-buffer.cpp
|
||||
virtgpu-forward-backend.cpp
|
||||
virtgpu-forward-impl.h
|
||||
apir_cs_ggml-rpc-front.cpp
|
||||
../../include/ggml-virtgpu.h)
|
||||
|
||||
target_include_directories(ggml-virtgpu PUBLIC /usr/include/libdrm/)
|
||||
|
||||
target_link_libraries(ggml-virtgpu PUBLIC ${DRM_LIBRARIES})
|
||||
target_include_directories(ggml-virtgpu PUBLIC ${DRM_INCLUDE_DIRS})
|
||||
target_compile_options(ggml-virtgpu PUBLIC ${DRM_CFLAGS_OTHER})
|
||||
|
||||
target_include_directories(ggml-virtgpu PUBLIC ./include)
|
||||
target_include_directories(ggml-virtgpu PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
# Ensure venus_hw.h is downloaded before building ggml-virtgpu
|
||||
add_dependencies(ggml-virtgpu venus_hw_header)
|
||||
|
||||
target_compile_options(ggml-virtgpu PRIVATE -std=c++20)
|
||||
else()
|
||||
message(STATUS "Not building the VirtGPU/Virglrenderer API Remoting frontend library")
|
||||
endif()
|
||||
|
||||
if (NOT GGML_VIRTGPU_BACKEND STREQUAL "OFF")
|
||||
add_subdirectory("backend")
|
||||
endif()
|
||||
87
ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp
Normal file
87
ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp
Normal file
@@ -0,0 +1,87 @@
|
||||
#include "backend/shared/apir_cs_rpc.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-remoting.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
apir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor) {
|
||||
apir_rpc_tensor result;
|
||||
result.id = reinterpret_cast<uint64_t>(tensor);
|
||||
result.type = tensor->type;
|
||||
if (tensor->buffer) {
|
||||
ggml_backend_buffer_t buffer = tensor->buffer;
|
||||
|
||||
result.buffer = BUFFER_TO_HOST_HANDLE(buffer);
|
||||
} else {
|
||||
result.buffer = 0;
|
||||
}
|
||||
for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
result.ne[i] = tensor->ne[i];
|
||||
result.nb[i] = tensor->nb[i];
|
||||
}
|
||||
result.op = tensor->op;
|
||||
for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
|
||||
result.op_params[i] = tensor->op_params[i];
|
||||
}
|
||||
result.flags = tensor->flags;
|
||||
for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
|
||||
result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
|
||||
}
|
||||
result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
|
||||
result.view_offs = tensor->view_offs;
|
||||
result.data = reinterpret_cast<uint64_t>(tensor->data);
|
||||
if (tensor->data) {
|
||||
if (!tensor->buffer) {
|
||||
GGML_ABORT("tensor has data but not buffer");
|
||||
}
|
||||
// tensor->data is serialized as an offset to the buffer base address
|
||||
result.data -= reinterpret_cast<uint64_t>(BUFFER_TO_GGML_CONTEXT(tensor->buffer)->base);
|
||||
}
|
||||
snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
|
||||
return result;
|
||||
}
|
||||
|
||||
void apir_add_tensor(ggml_tensor * tensor,
|
||||
std::vector<apir_rpc_tensor> & tensors,
|
||||
std::unordered_set<ggml_tensor *> & visited) {
|
||||
if (tensor == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (visited.find(tensor) != visited.end()) {
|
||||
return;
|
||||
}
|
||||
visited.insert(tensor);
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
apir_add_tensor(tensor->src[i], tensors, visited);
|
||||
}
|
||||
apir_add_tensor(tensor->view_src, tensors, visited);
|
||||
tensors.push_back(apir_serialize_tensor(tensor));
|
||||
}
|
||||
|
||||
void apir_serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
|
||||
uint32_t n_nodes = cgraph->n_nodes;
|
||||
std::vector<apir_rpc_tensor> tensors;
|
||||
std::unordered_set<ggml_tensor *> visited;
|
||||
for (uint32_t i = 0; i < n_nodes; i++) {
|
||||
apir_add_tensor(cgraph->nodes[i], tensors, visited);
|
||||
}
|
||||
// serialization format:
|
||||
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(apir_rpc_tensor)) |
|
||||
uint32_t n_tensors = tensors.size();
|
||||
int output_size =
|
||||
sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(apir_rpc_tensor);
|
||||
output.resize(output_size, 0);
|
||||
memcpy(output.data(), &n_nodes, sizeof(n_nodes));
|
||||
for (uint32_t i = 0; i < n_nodes; i++) {
|
||||
memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
|
||||
}
|
||||
uint32_t * out_ntensors = (uint32_t *) (output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
|
||||
*out_ntensors = n_tensors;
|
||||
apir_rpc_tensor * out_tensors =
|
||||
(apir_rpc_tensor *) (output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
|
||||
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(apir_rpc_tensor));
|
||||
}
|
||||
21
ggml/src/ggml-virtgpu/backend/CMakeLists.txt
Normal file
21
ggml/src/ggml-virtgpu/backend/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
|
||||
cmake_minimum_required(VERSION 3.19)
|
||||
cmake_policy(SET CMP0114 NEW)
|
||||
|
||||
message(STATUS "Enable the VirtGPU/Virglrenderer backend library")
|
||||
|
||||
ggml_add_backend_library(ggml-virtgpu-backend
|
||||
backend.cpp
|
||||
backend-dispatched.cpp
|
||||
backend-dispatched-backend.cpp
|
||||
backend-dispatched-device.cpp
|
||||
backend-dispatched-buffer.cpp
|
||||
backend-dispatched-buffer-type.cpp
|
||||
shared/api_remoting.h
|
||||
shared/apir_backend.h
|
||||
shared/apir_cs.h
|
||||
apir_cs_ggml-rpc-back.cpp)
|
||||
|
||||
target_compile_options(ggml-virtgpu-backend PRIVATE -std=c++20)
|
||||
|
||||
# Add include directory for ggml-backend-impl.h and other core headers
|
||||
target_include_directories(ggml-virtgpu-backend PRIVATE ../..)
|
||||
115
ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp
Normal file
115
ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp
Normal file
@@ -0,0 +1,115 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "shared/apir_cs_rpc.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
std::unordered_set<ggml_backend_buffer_t> backend_buffers;
|
||||
|
||||
void apir_track_backend_buffer(ggml_backend_buffer_t buffer) {
|
||||
backend_buffers.insert(buffer);
|
||||
}
|
||||
|
||||
bool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer) {
|
||||
auto it = backend_buffers.find(buffer);
|
||||
if (it == backend_buffers.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
backend_buffers.erase(it);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unordered_set<ggml_backend_buffer_t> apir_get_track_backend_buffers() {
|
||||
return backend_buffers;
|
||||
}
|
||||
|
||||
ggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor) {
|
||||
ggml_tensor * result =
|
||||
ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
||||
for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
result->nb[i] = tensor->nb[i];
|
||||
}
|
||||
result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
|
||||
if (result->buffer && backend_buffers.find(result->buffer) == backend_buffers.end()) {
|
||||
printf("WARNING: HOST BUFFER NOT FOUND | %p\n", (void *) result->buffer);
|
||||
result->buffer = nullptr;
|
||||
}
|
||||
|
||||
uint64_t tensor_data = tensor->data;
|
||||
if (result->buffer) {
|
||||
// require that the tensor data does not go beyond the buffer end
|
||||
uint64_t tensor_size = (uint64_t) ggml_nbytes(result);
|
||||
uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
|
||||
uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
|
||||
|
||||
// tensor->data is serialized as an offset to the buffer base address
|
||||
tensor_data += buffer_start;
|
||||
|
||||
GGML_ASSERT(tensor_data + tensor_size >= tensor_data); // check for overflow
|
||||
GGML_ASSERT(tensor_data >= buffer_start && tensor_data + tensor_size <= buffer_start + buffer_size);
|
||||
}
|
||||
|
||||
result->op = (ggml_op) tensor->op;
|
||||
for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
|
||||
result->op_params[i] = tensor->op_params[i];
|
||||
}
|
||||
result->flags = tensor->flags;
|
||||
result->data = reinterpret_cast<void *>(tensor_data);
|
||||
ggml_set_name(result, tensor->name);
|
||||
return result;
|
||||
}
|
||||
|
||||
ggml_tensor * apir_create_node(uint64_t id,
|
||||
ggml_context * ctx,
|
||||
const std::unordered_map<uint64_t, const apir_rpc_tensor *> & tensor_ptrs,
|
||||
std::unordered_map<uint64_t, ggml_tensor *> & tensor_map) {
|
||||
if (id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
if (tensor_map.find(id) != tensor_map.end()) {
|
||||
return tensor_map[id];
|
||||
}
|
||||
const apir_rpc_tensor * tensor = tensor_ptrs.at(id);
|
||||
ggml_tensor * result = apir_deserialize_tensor(ctx, tensor);
|
||||
if (result == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
tensor_map[id] = result;
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
result->src[i] = apir_create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
|
||||
}
|
||||
result->view_src = apir_create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
|
||||
result->view_offs = tensor->view_offs;
|
||||
return result;
|
||||
}
|
||||
|
||||
ggml_cgraph * apir_deserialize_graph(uint32_t n_nodes,
|
||||
uint32_t n_tensors,
|
||||
const apir_rpc_tensor * tensors,
|
||||
const uint64_t * nodes) {
|
||||
size_t buf_size = ggml_tensor_overhead() * (n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/buf_size,
|
||||
/*.mem_buffer =*/NULL,
|
||||
/*.no_alloc =*/true,
|
||||
};
|
||||
ggml_context * ctx = ggml_init(params);
|
||||
ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
|
||||
graph->n_nodes = n_nodes;
|
||||
std::unordered_map<uint64_t, const apir_rpc_tensor *> tensor_ptrs;
|
||||
for (uint32_t i = 0; i < n_tensors; i++) {
|
||||
tensor_ptrs[tensors[i].id] = &tensors[i];
|
||||
}
|
||||
std::unordered_map<uint64_t, ggml_tensor *> tensor_map;
|
||||
for (uint32_t i = 0; i < n_nodes; i++) {
|
||||
int64_t id;
|
||||
memcpy(&id, &nodes[i], sizeof(id));
|
||||
graph->nodes[i] = apir_create_node(id, ctx, tensor_ptrs, tensor_map);
|
||||
}
|
||||
|
||||
return graph;
|
||||
}
|
||||
13
ggml/src/ggml-virtgpu/backend/backend-convert.h
Normal file
13
ggml/src/ggml-virtgpu/backend/backend-convert.h
Normal file
@@ -0,0 +1,13 @@
|
||||
#include "shared/apir_backend.h"
|
||||
|
||||
#define BUFFER_TO_HOST_HANDLE(name) ggml_buffer_to_apir_handle(name)
|
||||
|
||||
static inline apir_buffer_host_handle_t ggml_buffer_to_apir_handle(ggml_backend_buffer_t buffer) {
|
||||
// in the backend, the buffer handle is the buffer pointer
|
||||
return (apir_buffer_host_handle_t) buffer;
|
||||
}
|
||||
|
||||
static inline apir_buffer_type_host_handle_t ggml_buffer_type_to_apir_handle(ggml_backend_buffer_type_t buft) {
|
||||
// in the backend, the buffer handle is the buffer pointer
|
||||
return (apir_buffer_type_host_handle_t) buft;
|
||||
}
|
||||
65
ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp
Normal file
65
ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp
Normal file
@@ -0,0 +1,65 @@
|
||||
#include "backend-dispatched.h"
|
||||
#include "backend-virgl-apir.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "shared/apir_backend.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(enc);
|
||||
|
||||
static bool async_backend_initialized = false;
|
||||
static bool async_backend;
|
||||
|
||||
if (!async_backend_initialized) {
|
||||
ggml_backend_dev_props props;
|
||||
|
||||
dev->iface.get_props(dev, &props);
|
||||
async_backend = props.caps.async;
|
||||
async_backend_initialized = true;
|
||||
}
|
||||
|
||||
uint32_t shmem_res_id;
|
||||
apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
|
||||
|
||||
const void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
|
||||
if (!shmem_data) {
|
||||
GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n");
|
||||
apir_decoder_set_fatal(dec);
|
||||
return 1;
|
||||
}
|
||||
size_t cgraph_size;
|
||||
apir_decode_size_t(dec, &cgraph_size);
|
||||
|
||||
apir_decoder secondary_dec = apir_new_decoder((const char *) shmem_data, cgraph_size);
|
||||
|
||||
ggml_cgraph * cgraph = apir_decode_ggml_cgraph(&secondary_dec, cgraph_size);
|
||||
|
||||
ggml_status status;
|
||||
#if APIR_BACKEND_CHECK_SUPPORTS_OP == 1
|
||||
for (int idx = 0; idx < cgraph->n_nodes; idx++) {
|
||||
ggml_tensor * op = ggml_graph_node(cgraph, idx);
|
||||
if (dev->iface.supports_op(dev, op)) {
|
||||
continue;
|
||||
}
|
||||
GGML_LOG_ERROR("Graph node %d (%s) not supported by the backend\n", idx, ggml_op_desc(op));
|
||||
|
||||
status = GGML_STATUS_ABORTED;
|
||||
apir_encode_ggml_status(enc, &status);
|
||||
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
status = bck->iface.graph_compute(bck, cgraph);
|
||||
|
||||
if (async_backend) {
|
||||
bck->iface.synchronize(bck);
|
||||
}
|
||||
|
||||
apir_encode_ggml_status(enc, &status);
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
#include "backend-dispatched.h"
|
||||
#include "backend-virgl-apir.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
uint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
ggml_backend_buffer_type_t buft;
|
||||
buft = apir_decode_ggml_buffer_type(dec);
|
||||
|
||||
const char * string = buft->iface.get_name(buft);
|
||||
|
||||
const size_t string_size = strlen(string) + 1;
|
||||
apir_encode_array_size(enc, string_size);
|
||||
apir_encode_char_array(enc, string, string_size);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
ggml_backend_buffer_type_t buft;
|
||||
buft = apir_decode_ggml_buffer_type(dec);
|
||||
|
||||
size_t value = buft->iface.get_alignment(buft);
|
||||
apir_encode_size_t(enc, &value);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
ggml_backend_buffer_type_t buft;
|
||||
buft = apir_decode_ggml_buffer_type(dec);
|
||||
|
||||
size_t value = buft->iface.get_max_size(buft);
|
||||
apir_encode_size_t(enc, &value);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
ggml_backend_buffer_type_t buft;
|
||||
buft = apir_decode_ggml_buffer_type(dec);
|
||||
|
||||
bool is_host = buft->iface.is_host(buft);
|
||||
apir_encode_bool_t(enc, &is_host);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
ggml_backend_buffer_type_t buft;
|
||||
buft = apir_decode_ggml_buffer_type(dec);
|
||||
|
||||
size_t size;
|
||||
apir_decode_size_t(dec, &size);
|
||||
|
||||
ggml_backend_buffer_t buffer;
|
||||
|
||||
buffer = buft->iface.alloc_buffer(buft, size);
|
||||
|
||||
apir_encode_ggml_buffer(enc, buffer);
|
||||
|
||||
if (buffer) {
|
||||
apir_track_backend_buffer(buffer);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
ggml_backend_buffer_type_t buft;
|
||||
buft = apir_decode_ggml_buffer_type(dec);
|
||||
|
||||
const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec);
|
||||
|
||||
size_t value = buft->iface.get_alloc_size(buft, op);
|
||||
|
||||
apir_encode_size_t(enc, &value);
|
||||
|
||||
return 0;
|
||||
}
|
||||
131
ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp
Normal file
131
ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp
Normal file
@@ -0,0 +1,131 @@
|
||||
#include "backend-dispatched.h"
|
||||
#include "backend-virgl-apir.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
ggml_backend_buffer_t buffer;
|
||||
buffer = apir_decode_ggml_buffer(dec);
|
||||
|
||||
uintptr_t base = (uintptr_t) buffer->iface.get_base(buffer);
|
||||
apir_encode_uintptr_t(enc, &base);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(enc);
|
||||
|
||||
ggml_backend_buffer_t buffer;
|
||||
buffer = apir_decode_ggml_buffer(dec);
|
||||
|
||||
ggml_tensor * tensor;
|
||||
// safe to remove the const qualifier here
|
||||
tensor = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec);
|
||||
|
||||
uint32_t shmem_res_id;
|
||||
apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
|
||||
|
||||
size_t offset;
|
||||
apir_decode_size_t(dec, &offset);
|
||||
|
||||
size_t size;
|
||||
apir_decode_size_t(dec, &size);
|
||||
|
||||
void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
|
||||
|
||||
if (!shmem_data) {
|
||||
GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
buffer->iface.set_tensor(buffer, tensor, shmem_data, offset, size);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(enc);
|
||||
|
||||
ggml_backend_buffer_t buffer;
|
||||
buffer = apir_decode_ggml_buffer(dec);
|
||||
|
||||
const ggml_tensor * tensor;
|
||||
// safe to remove the const qualifier here
|
||||
tensor = apir_decode_ggml_tensor(dec);
|
||||
|
||||
uint32_t shmem_res_id;
|
||||
apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
|
||||
|
||||
size_t offset;
|
||||
apir_decode_size_t(dec, &offset);
|
||||
|
||||
size_t size;
|
||||
apir_decode_size_t(dec, &size);
|
||||
|
||||
void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
|
||||
if (!shmem_data) {
|
||||
GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
buffer->iface.get_tensor(buffer, tensor, shmem_data, offset, size);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
|
||||
ggml_backend_buffer_t buffer;
|
||||
buffer = apir_decode_ggml_buffer(dec);
|
||||
|
||||
const ggml_tensor * src;
|
||||
// safe to remove the const qualifier here
|
||||
src = apir_decode_ggml_tensor(dec);
|
||||
ggml_tensor * dst = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec);
|
||||
|
||||
bool ret = buffer->iface.cpy_tensor(buffer, src, (ggml_tensor *) dst);
|
||||
|
||||
apir_encode_bool_t(enc, &ret);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(enc);
|
||||
|
||||
ggml_backend_buffer_t buffer;
|
||||
buffer = apir_decode_ggml_buffer(dec);
|
||||
|
||||
uint8_t value;
|
||||
apir_decode_uint8_t(dec, &value);
|
||||
|
||||
buffer->iface.clear(buffer, value);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(enc);
|
||||
|
||||
ggml_backend_buffer_t buffer;
|
||||
buffer = apir_decode_ggml_buffer(dec);
|
||||
|
||||
if (!apir_untrack_backend_buffer(buffer)) {
|
||||
GGML_LOG_WARN("%s: unknown buffer %p\n", __func__, (void *) buffer);
|
||||
return 1;
|
||||
}
|
||||
|
||||
buffer->iface.free_buffer(buffer);
|
||||
|
||||
return 0;
|
||||
}
|
||||
148
ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp
Normal file
148
ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp
Normal file
@@ -0,0 +1,148 @@
|
||||
#include "backend-dispatched.h"
|
||||
#include "backend-virgl-apir.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
uint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(dec);
|
||||
|
||||
int32_t dev_count = reg->iface.get_device_count(reg);
|
||||
apir_encode_int32_t(enc, &dev_count);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(dec);
|
||||
|
||||
int32_t dev_count = reg->iface.get_device_count(reg);
|
||||
apir_encode_int32_t(enc, &dev_count);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(dec);
|
||||
|
||||
const char * string = dev->iface.get_name(dev);
|
||||
|
||||
const size_t string_size = strlen(string) + 1;
|
||||
apir_encode_array_size(enc, string_size);
|
||||
apir_encode_char_array(enc, string, string_size);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(dec);
|
||||
|
||||
const char * string = dev->iface.get_description(dev);
|
||||
|
||||
const size_t string_size = strlen(string) + 1;
|
||||
apir_encode_array_size(enc, string_size);
|
||||
apir_encode_char_array(enc, string, string_size);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(dec);
|
||||
|
||||
uint32_t type = dev->iface.get_type(dev);
|
||||
apir_encode_uint32_t(enc, &type);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(dec);
|
||||
|
||||
size_t free, total;
|
||||
dev->iface.get_memory(dev, &free, &total);
|
||||
|
||||
apir_encode_size_t(enc, &free);
|
||||
apir_encode_size_t(enc, &total);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
|
||||
const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec);
|
||||
|
||||
bool supports_op = dev->iface.supports_op(dev, op);
|
||||
|
||||
apir_encode_bool_t(enc, &supports_op);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(dec);
|
||||
|
||||
ggml_backend_buffer_type_t bufft = dev->iface.get_buffer_type(dev);
|
||||
|
||||
apir_encode_ggml_buffer_type(enc, bufft);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(dec);
|
||||
|
||||
ggml_backend_dev_props props;
|
||||
dev->iface.get_props(dev, &props);
|
||||
|
||||
apir_encode_bool_t(enc, &props.caps.async);
|
||||
apir_encode_bool_t(enc, &props.caps.host_buffer);
|
||||
apir_encode_bool_t(enc, &props.caps.buffer_from_host_ptr);
|
||||
apir_encode_bool_t(enc, &props.caps.events);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(dec);
|
||||
|
||||
uint32_t shmem_res_id;
|
||||
apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
|
||||
|
||||
void * shmem_ptr = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
|
||||
if (!shmem_ptr) {
|
||||
GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n");
|
||||
apir_decoder_set_fatal(dec);
|
||||
return 1;
|
||||
}
|
||||
|
||||
size_t size;
|
||||
apir_decode_size_t(dec, &size);
|
||||
size_t max_tensor_size;
|
||||
apir_decode_size_t(dec, &max_tensor_size);
|
||||
|
||||
ggml_backend_buffer_t buffer;
|
||||
buffer = dev->iface.buffer_from_host_ptr(dev, shmem_ptr, size, max_tensor_size);
|
||||
|
||||
apir_encode_ggml_buffer(enc, buffer);
|
||||
apir_encode_ggml_buffer_type(enc, buffer->buft);
|
||||
|
||||
if (buffer) {
|
||||
apir_track_backend_buffer(buffer);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
46
ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp
Normal file
46
ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp
Normal file
@@ -0,0 +1,46 @@
|
||||
#include "backend-dispatched.h"
|
||||
#include "backend-virgl-apir.h"
|
||||
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
ggml_backend_reg_t reg = NULL;
|
||||
ggml_backend_dev_t dev = NULL;
|
||||
ggml_backend_t bck = NULL;
|
||||
|
||||
uint64_t timer_start = 0;
|
||||
uint64_t timer_total = 0;
|
||||
uint64_t timer_count = 0;
|
||||
|
||||
uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p) {
|
||||
if (reg != NULL) {
|
||||
GGML_LOG_WARN("%s: already initialized\n", __func__);
|
||||
return APIR_BACKEND_INITIALIZE_ALREADY_INITED;
|
||||
}
|
||||
ggml_backend_reg_t (*ggml_backend_reg_fct)(void) = (ggml_backend_reg_t (*)()) ggml_backend_reg_fct_p;
|
||||
|
||||
reg = ggml_backend_reg_fct();
|
||||
if (reg == NULL) {
|
||||
GGML_LOG_ERROR("%s: backend registration failed\n", __func__);
|
||||
return APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED;
|
||||
}
|
||||
|
||||
if (!reg->iface.get_device_count(reg)) {
|
||||
GGML_LOG_ERROR("%s: backend initialization failed: no device found\n", __func__);
|
||||
return APIR_BACKEND_INITIALIZE_NO_DEVICE;
|
||||
}
|
||||
|
||||
dev = reg->iface.get_device(reg, 0);
|
||||
|
||||
if (!dev) {
|
||||
GGML_LOG_ERROR("%s: backend initialization failed: no device received\n", __func__);
|
||||
return APIR_BACKEND_INITIALIZE_NO_DEVICE;
|
||||
}
|
||||
|
||||
bck = dev->iface.init_backend(dev, NULL);
|
||||
|
||||
return APIR_BACKEND_INITIALIZE_SUCCESS;
|
||||
}
|
||||
130
ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h
Normal file
130
ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h
Normal file
@@ -0,0 +1,130 @@
|
||||
#pragma once
|
||||
|
||||
/* device */
|
||||
uint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
|
||||
/* buffer-type */
|
||||
uint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
|
||||
/* buffer */
|
||||
uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
|
||||
/* backend */
|
||||
uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
|
||||
static inline const char * backend_dispatch_command_name(ApirBackendCommandType type) {
|
||||
switch (type) {
|
||||
/* device */
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT:
|
||||
return "backend_device_get_device_count";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_COUNT:
|
||||
return "backend_device_get_count";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_NAME:
|
||||
return "backend_device_get_name";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION:
|
||||
return "backend_device_get_description";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_TYPE:
|
||||
return "backend_device_get_type";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY:
|
||||
return "backend_device_get_memory";
|
||||
case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP:
|
||||
return "backend_device_supports_op";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE:
|
||||
return "backend_device_get_buffer_type";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_PROPS:
|
||||
return "backend_device_get_props";
|
||||
case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR:
|
||||
return "backend_device_buffer_from_ptr";
|
||||
/* buffer-type */
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME:
|
||||
return "backend_buffer_type_get_name";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT:
|
||||
return "backend_buffer_type_get_alignment";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE:
|
||||
return "backend_buffer_type_get_max_size";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST:
|
||||
return "backend_buffer_type_is_host";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER:
|
||||
return "backend_buffer_type_alloc_buffer";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE:
|
||||
return "backend_buffer_type_get_alloc_size";
|
||||
/* buffer */
|
||||
case APIR_COMMAND_TYPE_BUFFER_GET_BASE:
|
||||
return "backend_buffer_get_base";
|
||||
case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR:
|
||||
return "backend_buffer_set_tensor";
|
||||
case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR:
|
||||
return "backend_buffer_get_tensor";
|
||||
case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR:
|
||||
return "backend_buffer_cpy_tensor";
|
||||
case APIR_COMMAND_TYPE_BUFFER_CLEAR:
|
||||
return "backend_buffer_clear";
|
||||
case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER:
|
||||
return "backend_buffer_free_buffer";
|
||||
/* backend */
|
||||
case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE:
|
||||
return "backend_backend_graph_compute";
|
||||
|
||||
default:
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {
|
||||
|
||||
/* device */
|
||||
|
||||
/* APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT = */ backend_device_get_device_count,
|
||||
/* APIR_COMMAND_TYPE_DEVICE_GET_COUNT = */ backend_device_get_count,
|
||||
/* APIR_COMMAND_TYPE_DEVICE_GET_NAME = */ backend_device_get_name,
|
||||
/* APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION = */ backend_device_get_description,
|
||||
/* APIR_COMMAND_TYPE_DEVICE_GET_TYPE = */ backend_device_get_type,
|
||||
/* APIR_COMMAND_TYPE_DEVICE_GET_MEMORY = */ backend_device_get_memory,
|
||||
/* APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP = */ backend_device_supports_op,
|
||||
/* APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE = */ backend_device_get_buffer_type,
|
||||
/* APIR_COMMAND_TYPE_DEVICE_GET_PROPS = */ backend_device_get_props,
|
||||
/* APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR = */ backend_device_buffer_from_ptr,
|
||||
|
||||
/* buffer-type */
|
||||
|
||||
/* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME = */ backend_buffer_type_get_name,
|
||||
/* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT = */ backend_buffer_type_get_alignment,
|
||||
/* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE = */ backend_buffer_type_get_max_size,
|
||||
/* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST = */ backend_buffer_type_is_host,
|
||||
/* APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER = */ backend_buffer_type_alloc_buffer,
|
||||
/* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE = */ backend_buffer_type_get_alloc_size,
|
||||
|
||||
/* buffer */
|
||||
|
||||
/* APIR_COMMAND_TYPE_BUFFER_GET_BASE = */ backend_buffer_get_base,
|
||||
/* APIR_COMMAND_TYPE_BUFFER_SET_TENSOR = */ backend_buffer_set_tensor,
|
||||
/* APIR_COMMAND_TYPE_BUFFER_GET_TENSOR = */ backend_buffer_get_tensor,
|
||||
/* APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR = */ backend_buffer_cpy_tensor,
|
||||
/* APIR_COMMAND_TYPE_BUFFER_CLEAR = */ backend_buffer_clear,
|
||||
/* APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER = */ backend_buffer_free_buffer,
|
||||
|
||||
/* backend */
|
||||
|
||||
/* APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE = */ backend_backend_graph_compute,
|
||||
};
|
||||
}
|
||||
23
ggml/src/ggml-virtgpu/backend/backend-dispatched.h
Normal file
23
ggml/src/ggml-virtgpu/backend/backend-dispatched.h
Normal file
@@ -0,0 +1,23 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstddef>
|
||||
|
||||
#include <ggml-backend.h>
|
||||
|
||||
#include "backend-convert.h"
|
||||
#include "backend-virgl-apir.h"
|
||||
#include "shared/apir_backend.h"
|
||||
#include "shared/apir_cs.h"
|
||||
#include "shared/apir_cs_ggml.h"
|
||||
|
||||
struct virgl_apir_context {
|
||||
uint32_t ctx_id;
|
||||
virgl_apir_callbacks * iface;
|
||||
};
|
||||
|
||||
typedef uint32_t (*backend_dispatch_t)(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
|
||||
#include "backend-dispatched.gen.h"
|
||||
|
||||
uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p);
|
||||
32
ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h
Normal file
32
ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h
Normal file
@@ -0,0 +1,32 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "shared/api_remoting.h"
|
||||
|
||||
#include <cstdarg>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
|
||||
extern ggml_backend_reg_t reg;
|
||||
extern ggml_backend_dev_t dev;
|
||||
extern ggml_backend_t bck;
|
||||
|
||||
struct virgl_apir_callbacks {
|
||||
const char * (*get_config)(uint32_t virgl_ctx_id, const char * key);
|
||||
void * (*get_shmem_ptr)(uint32_t virgl_ctx_id, uint32_t res_id);
|
||||
};
|
||||
|
||||
extern "C" {
|
||||
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs);
|
||||
void apir_backend_deinit(uint32_t virgl_ctx_id);
|
||||
uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id,
|
||||
virgl_apir_callbacks * virgl_cbs,
|
||||
uint32_t cmd_type,
|
||||
char * dec_cur,
|
||||
const char * dec_end,
|
||||
char * enc_cur,
|
||||
const char * enc_end,
|
||||
char ** enc_cur_after);
|
||||
}
|
||||
148
ggml/src/ggml-virtgpu/backend/backend.cpp
Normal file
148
ggml/src/ggml-virtgpu/backend/backend.cpp
Normal file
@@ -0,0 +1,148 @@
|
||||
#include "backend-dispatched.h"
|
||||
#include "backend-virgl-apir.h"
|
||||
|
||||
#include "shared/api_remoting.h"
|
||||
#include "shared/apir_backend.h"
|
||||
#include "shared/apir_cs.h"
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <ggml-backend.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#define APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV "APIR_LLAMA_CPP_GGML_LIBRARY_PATH"
|
||||
#define APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV "APIR_LLAMA_CPP_GGML_LIBRARY_REG"
|
||||
#define APIR_LLAMA_CPP_LOG_TO_FILE_ENV "APIR_LLAMA_CPP_LOG_TO_FILE"
|
||||
|
||||
#define GGML_DEFAULT_BACKEND_REG "ggml_backend_init"
|
||||
|
||||
static void * backend_library_handle = NULL;
|
||||
static FILE * apir_logfile = NULL;
|
||||
|
||||
static void log_to_file_callback(enum ggml_log_level level, const char * text, void * user_data) {
|
||||
FILE * logfile = (FILE *)user_data;
|
||||
fprintf(logfile, "[%d] %s", level, text);
|
||||
fflush(logfile);
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
void apir_backend_deinit(uint32_t virgl_ctx_id) {
|
||||
GGML_UNUSED(virgl_ctx_id);
|
||||
|
||||
auto buffers = apir_get_track_backend_buffers();
|
||||
for (const auto & buffer : buffers) {
|
||||
apir_untrack_backend_buffer(buffer);
|
||||
buffer->iface.free_buffer(buffer);
|
||||
}
|
||||
|
||||
if (dev) {
|
||||
size_t free, total;
|
||||
dev->iface.get_memory(dev, &free, &total);
|
||||
GGML_LOG_INFO("%s: free memory: %ld MB\n", __func__, (size_t) free / 1024 / 1024);
|
||||
}
|
||||
|
||||
if (backend_library_handle) {
|
||||
GGML_LOG_INFO("%s: The GGML backend library was loaded. Unloading it.\n", __func__);
|
||||
dlclose(backend_library_handle);
|
||||
backend_library_handle = NULL;
|
||||
}
|
||||
|
||||
if (apir_logfile) {
|
||||
fclose(apir_logfile);
|
||||
apir_logfile = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
#define APIR_GGML_LIBRARY_PATH_KEY "ggml.library.path"
|
||||
#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg"
|
||||
|
||||
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs) {
|
||||
const char * dlsym_error;
|
||||
|
||||
const char * apir_log_to_file = getenv(APIR_LLAMA_CPP_LOG_TO_FILE_ENV);
|
||||
if (apir_log_to_file) {
|
||||
apir_logfile = fopen(apir_log_to_file, "w");
|
||||
if (apir_logfile) {
|
||||
ggml_log_set(log_to_file_callback, apir_logfile);
|
||||
} else {
|
||||
GGML_LOG_INFO("Could not open the log file at '%s'\n", apir_log_to_file);
|
||||
}
|
||||
}
|
||||
|
||||
const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY);
|
||||
const char * virgl_library_reg = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_REG_KEY);
|
||||
const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG;
|
||||
|
||||
if (!library_name) {
|
||||
GGML_LOG_ERROR("cannot open the GGML library: env var '%s' not defined\n", APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV);
|
||||
|
||||
return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;
|
||||
}
|
||||
|
||||
backend_library_handle = dlopen(library_name, RTLD_LAZY);
|
||||
|
||||
if (!backend_library_handle) {
|
||||
GGML_LOG_ERROR("cannot open the GGML library: %s\n", dlerror());
|
||||
|
||||
return APIR_LOAD_LIBRARY_CANNOT_OPEN;
|
||||
}
|
||||
|
||||
if (!library_reg) {
|
||||
GGML_LOG_ERROR("cannot register the GGML library: env var '%s' not defined\n", APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV);
|
||||
|
||||
return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;
|
||||
}
|
||||
|
||||
void * ggml_backend_reg_fct = dlsym(backend_library_handle, library_reg);
|
||||
dlsym_error = dlerror();
|
||||
if (dlsym_error) {
|
||||
GGML_LOG_ERROR("cannot find the GGML backend registration symbol '%s' (from %s): %s\n", library_reg,
|
||||
APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error);
|
||||
|
||||
return APIR_LOAD_LIBRARY_SYMBOL_MISSING;
|
||||
}
|
||||
|
||||
uint32_t ret = backend_dispatch_initialize(ggml_backend_reg_fct);
|
||||
|
||||
return (ApirLoadLibraryReturnCode) (APIR_LOAD_LIBRARY_INIT_BASE_INDEX + ret);
|
||||
}
|
||||
|
||||
uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id,
|
||||
virgl_apir_callbacks * virgl_cbs,
|
||||
uint32_t cmd_type,
|
||||
char * dec_cur,
|
||||
const char * dec_end,
|
||||
char * enc_cur,
|
||||
const char * enc_end,
|
||||
char ** enc_cur_after) {
|
||||
apir_encoder enc = {
|
||||
.cur = enc_cur,
|
||||
.start = enc_cur,
|
||||
.end = enc_end,
|
||||
.fatal = false,
|
||||
};
|
||||
|
||||
apir_decoder dec = {
|
||||
.cur = dec_cur,
|
||||
.end = dec_end,
|
||||
.fatal = false,
|
||||
};
|
||||
|
||||
virgl_apir_context ctx = {
|
||||
.ctx_id = virgl_ctx_id,
|
||||
.iface = virgl_cbs,
|
||||
};
|
||||
|
||||
if (cmd_type >= APIR_BACKEND_DISPATCH_TABLE_COUNT) {
|
||||
GGML_LOG_ERROR("Received an invalid dispatch index (%d >= %d)\n", cmd_type, APIR_BACKEND_DISPATCH_TABLE_COUNT);
|
||||
return APIR_BACKEND_FORWARD_INDEX_INVALID;
|
||||
}
|
||||
|
||||
backend_dispatch_t forward_fct = apir_backend_dispatch_table[cmd_type];
|
||||
uint32_t ret = forward_fct(&enc, &dec, &ctx);
|
||||
|
||||
*enc_cur_after = enc.cur;
|
||||
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
90
ggml/src/ggml-virtgpu/backend/shared/api_remoting.h
Normal file
90
ggml/src/ggml-virtgpu/backend/shared/api_remoting.h
Normal file
@@ -0,0 +1,90 @@
|
||||
#pragma once
|
||||
|
||||
/* the rest of this file must match virglrenderer/src/apir-protocol.h */
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#define APIR_PROTOCOL_MAJOR 0
|
||||
#define APIR_PROTOCOL_MINOR 1
|
||||
|
||||
#define APIR_HANDSHAKE_MAGIC 0xab1e
|
||||
|
||||
enum ApirCommandType {
|
||||
APIR_COMMAND_TYPE_HANDSHAKE = 0,
|
||||
APIR_COMMAND_TYPE_LOADLIBRARY = 1,
|
||||
APIR_COMMAND_TYPE_FORWARD = 2,
|
||||
|
||||
APIR_COMMAND_TYPE_LENGTH = 3,
|
||||
};
|
||||
|
||||
typedef uint64_t ApirCommandFlags;
|
||||
|
||||
enum ApirLoadLibraryReturnCode {
|
||||
APIR_LOAD_LIBRARY_SUCCESS = 0,
|
||||
APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR = 1,
|
||||
APIR_LOAD_LIBRARY_ALREADY_LOADED = 2,
|
||||
APIR_LOAD_LIBRARY_ENV_VAR_MISSING = 3,
|
||||
APIR_LOAD_LIBRARY_CANNOT_OPEN = 4,
|
||||
APIR_LOAD_LIBRARY_SYMBOL_MISSING = 5,
|
||||
APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6, // anything above this is a APIR backend library initialization return code
|
||||
};
|
||||
|
||||
enum ApirForwardReturnCode {
|
||||
APIR_FORWARD_SUCCESS = 0,
|
||||
APIR_FORWARD_NO_DISPATCH_FCT = 1,
|
||||
APIR_FORWARD_TIMEOUT = 2,
|
||||
|
||||
APIR_FORWARD_BASE_INDEX = 3, // anything above this is a APIR backend library forward return code
|
||||
} ;
|
||||
|
||||
__attribute__((unused)) static inline const char * apir_command_name(ApirCommandType type) {
|
||||
switch (type) {
|
||||
case APIR_COMMAND_TYPE_HANDSHAKE:
|
||||
return "HandShake";
|
||||
case APIR_COMMAND_TYPE_LOADLIBRARY:
|
||||
return "LoadLibrary";
|
||||
case APIR_COMMAND_TYPE_FORWARD:
|
||||
return "Forward";
|
||||
default:
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
__attribute__((unused)) static const char * apir_load_library_error(ApirLoadLibraryReturnCode code) {
|
||||
#define APIR_LOAD_LIBRARY_ERROR(code_name) \
|
||||
do { \
|
||||
if (code == code_name) \
|
||||
return #code_name; \
|
||||
} while (0)
|
||||
|
||||
APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_SUCCESS);
|
||||
APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR);
|
||||
APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_ALREADY_LOADED);
|
||||
APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_ENV_VAR_MISSING);
|
||||
APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_CANNOT_OPEN);
|
||||
APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_SYMBOL_MISSING);
|
||||
APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_INIT_BASE_INDEX);
|
||||
|
||||
return "Unknown APIR_COMMAND_TYPE_LoadLibrary error";
|
||||
|
||||
#undef APIR_LOAD_LIBRARY_ERROR
|
||||
}
|
||||
|
||||
__attribute__((unused)) static const char * apir_forward_error(ApirForwardReturnCode code) {
|
||||
#define APIR_FORWARD_ERROR(code_name) \
|
||||
do { \
|
||||
if (code == code_name) \
|
||||
return #code_name; \
|
||||
} while (0)
|
||||
|
||||
APIR_FORWARD_ERROR(APIR_FORWARD_SUCCESS);
|
||||
APIR_FORWARD_ERROR(APIR_FORWARD_NO_DISPATCH_FCT);
|
||||
APIR_FORWARD_ERROR(APIR_FORWARD_TIMEOUT);
|
||||
APIR_FORWARD_ERROR(APIR_FORWARD_BASE_INDEX);
|
||||
|
||||
return "Unknown APIR_COMMAND_TYPE_FORWARD error";
|
||||
|
||||
#undef APIR_FORWARD_ERROR
|
||||
}
|
||||
36
ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h
Normal file
36
ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h
Normal file
@@ -0,0 +1,36 @@
|
||||
typedef enum ApirBackendCommandType {
|
||||
|
||||
/* device */
|
||||
APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT = 0,
|
||||
APIR_COMMAND_TYPE_DEVICE_GET_COUNT = 1,
|
||||
APIR_COMMAND_TYPE_DEVICE_GET_NAME = 2,
|
||||
APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION = 3,
|
||||
APIR_COMMAND_TYPE_DEVICE_GET_TYPE = 4,
|
||||
APIR_COMMAND_TYPE_DEVICE_GET_MEMORY = 5,
|
||||
APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP = 6,
|
||||
APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE = 7,
|
||||
APIR_COMMAND_TYPE_DEVICE_GET_PROPS = 8,
|
||||
APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR = 9,
|
||||
|
||||
/* buffer-type */
|
||||
APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME = 10,
|
||||
APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT = 11,
|
||||
APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE = 12,
|
||||
APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST = 13,
|
||||
APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER = 14,
|
||||
APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE = 15,
|
||||
|
||||
/* buffer */
|
||||
APIR_COMMAND_TYPE_BUFFER_GET_BASE = 16,
|
||||
APIR_COMMAND_TYPE_BUFFER_SET_TENSOR = 17,
|
||||
APIR_COMMAND_TYPE_BUFFER_GET_TENSOR = 18,
|
||||
APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR = 19,
|
||||
APIR_COMMAND_TYPE_BUFFER_CLEAR = 20,
|
||||
APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER = 21,
|
||||
|
||||
/* backend */
|
||||
APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE = 22,
|
||||
|
||||
// last command_type index + 1
|
||||
APIR_BACKEND_DISPATCH_TABLE_COUNT = 23,
|
||||
} ApirBackendCommandType;
|
||||
46
ggml/src/ggml-virtgpu/backend/shared/apir_backend.h
Normal file
46
ggml/src/ggml-virtgpu/backend/shared/apir_backend.h
Normal file
@@ -0,0 +1,46 @@
|
||||
#pragma once
|
||||
|
||||
#include "apir_backend.gen.h"
|
||||
|
||||
#include <stdint.h> // for uintptr_t
|
||||
#include <time.h> // for timespec, clock_gettime
|
||||
|
||||
#define APIR_BACKEND_INITIALIZE_SUCCESS 0
|
||||
#define APIR_BACKEND_INITIALIZE_CANNOT_OPEN_BACKEND_LIBRARY 1
|
||||
#define APIR_BACKEND_INITIALIZE_CANNOT_OPEN_GGML_LIBRARY 2
|
||||
#define APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS 3
|
||||
#define APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS 4
|
||||
#define APIR_BACKEND_INITIALIZE_BACKEND_FAILED 5
|
||||
#define APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED 6
|
||||
#define APIR_BACKEND_INITIALIZE_ALREADY_INITED 7
|
||||
#define APIR_BACKEND_INITIALIZE_NO_DEVICE 8
|
||||
|
||||
|
||||
// new entries here need to be added to the apir_backend_initialize_error function below
|
||||
|
||||
#define APIR_BACKEND_FORWARD_INDEX_INVALID 6
|
||||
|
||||
// 0 is fast, 1 avoids the backend to crash if an unsupported tensor is received
|
||||
#define APIR_BACKEND_CHECK_SUPPORTS_OP 0
|
||||
|
||||
typedef uintptr_t apir_buffer_type_host_handle_t;
|
||||
typedef uintptr_t apir_buffer_host_handle_t;
|
||||
|
||||
static const char * apir_backend_initialize_error(int code) {
|
||||
#define APIR_BACKEND_INITIALIZE_ERROR(code_name) \
|
||||
do { \
|
||||
if (code == code_name) \
|
||||
return #code_name; \
|
||||
} while (0)
|
||||
|
||||
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_SUCCESS);
|
||||
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_CANNOT_OPEN_BACKEND_LIBRARY);
|
||||
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_CANNOT_OPEN_GGML_LIBRARY);
|
||||
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS);
|
||||
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS);
|
||||
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_FAILED);
|
||||
|
||||
return "Unknown APIR_BACKEND_INITIALIZE error:/";
|
||||
|
||||
#undef APIR_BACKEND_INITIALIZE_ERROR
|
||||
}
|
||||
383
ggml/src/ggml-virtgpu/backend/shared/apir_cs.h
Normal file
383
ggml/src/ggml-virtgpu/backend/shared/apir_cs.h
Normal file
@@ -0,0 +1,383 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
|
||||
#define likely(x) __builtin_expect(!!(x), 1)
|
||||
#define unlikely(x) __builtin_expect(!!(x), 0)
|
||||
|
||||
struct apir_encoder {
|
||||
char * cur;
|
||||
const char * start;
|
||||
const char * end;
|
||||
bool fatal;
|
||||
|
||||
};
|
||||
|
||||
struct apir_decoder {
|
||||
const char * cur;
|
||||
const char * end;
|
||||
bool fatal;
|
||||
};
|
||||
|
||||
/*
|
||||
* new encoder and decoder
|
||||
*/
|
||||
|
||||
static apir_decoder apir_new_decoder(const char * ptr, size_t size) {
|
||||
apir_decoder dec = {
|
||||
.cur = ptr,
|
||||
.end = ptr + size,
|
||||
.fatal = false,
|
||||
};
|
||||
|
||||
return dec;
|
||||
}
|
||||
|
||||
static apir_encoder apir_new_encoder(char * ptr, size_t size) {
|
||||
apir_encoder enc = {
|
||||
.cur = ptr,
|
||||
.start = ptr,
|
||||
.end = ptr + size,
|
||||
.fatal = false,
|
||||
};
|
||||
|
||||
return enc;
|
||||
}
|
||||
|
||||
/*
|
||||
* fatal flag handling
|
||||
*/
|
||||
|
||||
static inline void apir_encoder_reset_fatal(apir_encoder * enc) {
|
||||
enc->fatal = false;
|
||||
}
|
||||
|
||||
static inline void apir_encoder_set_fatal(apir_encoder * enc) {
|
||||
enc->fatal = true;
|
||||
}
|
||||
|
||||
static inline bool apir_encoder_get_fatal(const apir_encoder * enc) {
|
||||
return enc->fatal;
|
||||
}
|
||||
|
||||
static inline void apir_decoder_reset_fatal(apir_decoder * dec) {
|
||||
dec->fatal = false;
|
||||
}
|
||||
|
||||
static inline void apir_decoder_set_fatal(apir_decoder * dec) {
|
||||
dec->fatal = true;
|
||||
}
|
||||
|
||||
static inline bool apir_decoder_get_fatal(const apir_decoder * dec) {
|
||||
return dec->fatal;
|
||||
}
|
||||
|
||||
/*
|
||||
* encode peek
|
||||
*/
|
||||
|
||||
static inline bool apir_decoder_peek_internal(apir_decoder * dec,
|
||||
size_t size,
|
||||
void * val,
|
||||
size_t val_size) {
|
||||
assert(val_size <= size);
|
||||
|
||||
if (unlikely(size > (size_t) (dec->end - dec->cur))) {
|
||||
GGML_LOG_ERROR("reading too much from the decoder ...\n");
|
||||
apir_decoder_set_fatal(dec);
|
||||
memset(val, 0, val_size);
|
||||
return false;
|
||||
}
|
||||
|
||||
/* we should not rely on the compiler to optimize away memcpy... */
|
||||
memcpy(val, dec->cur, val_size);
|
||||
return true;
|
||||
}
|
||||
|
||||
static inline void apir_decoder_peek(apir_decoder * dec, size_t size, void * val, size_t val_size) {
|
||||
apir_decoder_peek_internal(dec, size, val, val_size);
|
||||
}
|
||||
|
||||
static inline const void * apir_decoder_use_inplace(apir_decoder * dec, size_t size) {
|
||||
if (unlikely(size > (size_t) (dec->end - dec->cur))) {
|
||||
GGML_LOG_ERROR("reading too much from the decoder ...\n");
|
||||
apir_decoder_set_fatal(dec);
|
||||
return NULL;
|
||||
}
|
||||
const void * addr = dec->cur;
|
||||
dec->cur += size;
|
||||
|
||||
return addr;
|
||||
}
|
||||
|
||||
/*
|
||||
* read/write
|
||||
*/
|
||||
|
||||
static inline void apir_decoder_read(apir_decoder * dec, size_t size, void * val, size_t val_size) {
|
||||
if (apir_decoder_peek_internal(dec, size, val, val_size)) {
|
||||
dec->cur += size;
|
||||
}
|
||||
}
|
||||
|
||||
static inline char * apir_encoder_write(apir_encoder * enc, size_t size, const void * val, size_t val_size) {
|
||||
assert(val_size <= size);
|
||||
assert(size <= ((size_t) (enc->end - enc->cur)));
|
||||
|
||||
char * write_addr = enc->cur;
|
||||
/* we should not rely on the compiler to optimize away memcpy... */
|
||||
memcpy(write_addr, val, val_size);
|
||||
enc->cur += size;
|
||||
|
||||
return write_addr;
|
||||
}
|
||||
|
||||
/*
|
||||
* encode/decode
|
||||
*/
|
||||
|
||||
static inline void apir_decode(apir_decoder * dec, size_t size, void * data, size_t data_size) {
|
||||
assert(size % 4 == 0);
|
||||
apir_decoder_read(dec, size, data, data_size);
|
||||
}
|
||||
|
||||
static inline void apir_encode(apir_encoder * enc, size_t size, const void * data, size_t data_size) {
|
||||
assert(size % 4 == 0);
|
||||
apir_encoder_write(enc, size, data, data_size);
|
||||
}
|
||||
|
||||
/*
|
||||
* typed encode/decode
|
||||
*/
|
||||
|
||||
/* uint8_t */
|
||||
|
||||
static inline void apir_encode_uint8_t(apir_encoder * enc, const uint8_t * val) {
|
||||
apir_encode(enc, sizeof(int), val, sizeof(*val));
|
||||
}
|
||||
|
||||
static inline void apir_decode_uint8_t(apir_decoder * dec, uint8_t * val) {
|
||||
apir_decode(dec, sizeof(int), val, sizeof(*val));
|
||||
}
|
||||
|
||||
/* uint64_t */
|
||||
|
||||
static inline void apir_encode_uint64_t(apir_encoder * enc, const uint64_t * val) {
|
||||
apir_encode(enc, 8, val, sizeof(*val));
|
||||
}
|
||||
|
||||
static inline void apir_decode_uint64_t(apir_decoder * dec, uint64_t * val) {
|
||||
apir_decode(dec, 8, val, sizeof(*val));
|
||||
}
|
||||
|
||||
static inline void apir_encode_uint64_t_array(apir_encoder * enc, const uint64_t * val, uint32_t count) {
|
||||
const size_t size = sizeof(*val) * count;
|
||||
assert(size >= count);
|
||||
apir_encode(enc, size, val, size);
|
||||
}
|
||||
|
||||
static inline void apir_decode_uint64_t_array(apir_decoder * dec, uint64_t * val, uint32_t count) {
|
||||
const size_t size = sizeof(*val) * count;
|
||||
assert(size >= count);
|
||||
apir_decode(dec, size, val, size);
|
||||
}
|
||||
|
||||
static inline const uint64_t * apir_decode_uint64_t_array_inplace(apir_decoder * dec, uint32_t count) {
|
||||
return (uint64_t *) (uintptr_t) apir_decoder_use_inplace(dec, count * sizeof(uint64_t));
|
||||
}
|
||||
|
||||
/* int32_t */
|
||||
|
||||
static inline void apir_encode_int32_t(apir_encoder * enc, const int32_t * val) {
|
||||
apir_encode(enc, 4, val, sizeof(*val));
|
||||
}
|
||||
|
||||
static inline void apir_decode_int32_t(apir_decoder * dec, int32_t * val) {
|
||||
apir_decode(dec, 4, val, sizeof(*val));
|
||||
}
|
||||
|
||||
static inline void apir_encode_int32_t_array(apir_encoder * enc, const int32_t * val, uint32_t count) {
|
||||
const size_t size = sizeof(*val) * count;
|
||||
assert(size >= count);
|
||||
apir_encode(enc, size, val, size);
|
||||
}
|
||||
|
||||
static inline void apir_decode_int32_t_array(apir_decoder * dec, int32_t * val, uint32_t count) {
|
||||
const size_t size = sizeof(*val) * count;
|
||||
assert(size >= count);
|
||||
apir_decode(dec, size, val, size);
|
||||
}
|
||||
|
||||
/* array size (uint64_t) */
|
||||
|
||||
static inline void apir_encode_array_size(apir_encoder * enc, uint64_t size) {
|
||||
apir_encode_uint64_t(enc, &size);
|
||||
}
|
||||
|
||||
static inline uint64_t apir_decode_array_size(apir_decoder * dec, uint64_t expected_size) {
|
||||
uint64_t size;
|
||||
apir_decode_uint64_t(dec, &size);
|
||||
if (size != expected_size) {
|
||||
GGML_LOG_ERROR("Couldn't decode array from the decoder\n");
|
||||
apir_decoder_set_fatal(dec);
|
||||
size = 0;
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
static inline uint64_t apir_decode_array_size_unchecked(apir_decoder * dec) {
|
||||
uint64_t size;
|
||||
apir_decode_uint64_t(dec, &size);
|
||||
return size;
|
||||
}
|
||||
|
||||
/* non-array pointer */
|
||||
|
||||
static inline bool apir_encode_simple_pointer(apir_encoder * enc, const void * val) {
|
||||
apir_encode_array_size(enc, val ? 1 : 0);
|
||||
return val;
|
||||
}
|
||||
|
||||
static inline bool apir_decode_simple_pointer(apir_decoder * dec) {
|
||||
return apir_decode_array_size_unchecked(dec);
|
||||
}
|
||||
|
||||
/* uint32_t */
|
||||
|
||||
static inline void apir_encode_uint32_t(apir_encoder * enc, const uint32_t * val) {
|
||||
apir_encode(enc, 4, val, sizeof(*val));
|
||||
}
|
||||
|
||||
static inline void apir_decode_uint32_t(apir_decoder * dec, uint32_t * val) {
|
||||
apir_decode(dec, 4, val, sizeof(*val));
|
||||
}
|
||||
|
||||
static inline void apir_encode_uint32_t_array(apir_encoder * enc, const uint32_t * val, uint32_t count) {
|
||||
const size_t size = sizeof(*val) * count;
|
||||
assert(size >= count);
|
||||
apir_encode(enc, size, val, size);
|
||||
}
|
||||
|
||||
static inline void apir_decode_uint32_t_array(apir_decoder * dec, uint32_t * val, uint32_t count) {
|
||||
const size_t size = sizeof(*val) * count;
|
||||
assert(size >= count);
|
||||
apir_decode(dec, size, val, size);
|
||||
}
|
||||
|
||||
/* size_t */
|
||||
|
||||
static inline void apir_encode_size_t(apir_encoder * enc, const size_t * val) {
|
||||
const uint64_t tmp = *val;
|
||||
apir_encode_uint64_t(enc, &tmp);
|
||||
}
|
||||
|
||||
static inline void apir_decode_size_t(apir_decoder * dec, size_t * val) {
|
||||
uint64_t tmp;
|
||||
apir_decode_uint64_t(dec, &tmp);
|
||||
*val = tmp;
|
||||
}
|
||||
|
||||
static inline void apir_encode_size_t_array(apir_encoder * enc, const size_t * val, uint32_t count) {
|
||||
if (sizeof(size_t) == sizeof(uint64_t)) {
|
||||
apir_encode_uint64_t_array(enc, (const uint64_t *) val, count);
|
||||
} else {
|
||||
for (uint32_t i = 0; i < count; i++) {
|
||||
apir_encode_size_t(enc, &val[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline void apir_decode_size_t_array(apir_decoder * dec, size_t * val, uint32_t count) {
|
||||
if (sizeof(size_t) == sizeof(uint64_t)) {
|
||||
apir_decode_uint64_t_array(dec, (uint64_t *) val, count);
|
||||
} else {
|
||||
for (uint32_t i = 0; i < count; i++) {
|
||||
apir_decode_size_t(dec, &val[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* opaque blob */
|
||||
|
||||
static inline void apir_encode_blob_array(apir_encoder * enc, const void * val, size_t size) {
|
||||
apir_encode(enc, (size + 3) & ~3, val, size);
|
||||
}
|
||||
|
||||
static inline void apir_decode_blob_array(apir_decoder * dec, void * val, size_t size) {
|
||||
apir_decode(dec, (size + 3) & ~3, val, size);
|
||||
}
|
||||
|
||||
/* string */
|
||||
|
||||
static inline void apir_encode_char_array(apir_encoder * enc, const char * val, size_t size) {
|
||||
assert(size && strlen(val) < size);
|
||||
apir_encode_blob_array(enc, val, size);
|
||||
}
|
||||
|
||||
static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t size) {
|
||||
apir_decode_blob_array(dec, val, size);
|
||||
if (size) {
|
||||
val[size - 1] = '\0';
|
||||
} else {
|
||||
GGML_LOG_ERROR("Couldn't decode the blog array\n");
|
||||
apir_decoder_set_fatal(dec);
|
||||
}
|
||||
}
|
||||
|
||||
/* (temp) buffer allocation */
|
||||
|
||||
static inline void * apir_decoder_alloc_array(size_t size, size_t count) {
|
||||
size_t alloc_size;
|
||||
if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) {
|
||||
GGML_LOG_ERROR("overflow in array allocation of %zu * %zu bytes\n", size, count);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return malloc(alloc_size);
|
||||
}
|
||||
|
||||
/* bool */
|
||||
|
||||
static inline void apir_encode_bool_t(apir_encoder * enc, const bool * val) {
|
||||
apir_encode(enc, sizeof(int), val, sizeof(bool));
|
||||
}
|
||||
|
||||
static inline void apir_decode_bool_t(apir_decoder * dec, bool * val) {
|
||||
apir_decode(dec, sizeof(int), val, sizeof(bool));
|
||||
}
|
||||
|
||||
/* apir_buffer_type_host_handle_t */
|
||||
|
||||
static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc,
|
||||
const apir_buffer_type_host_handle_t * val) {
|
||||
apir_encode(enc, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
|
||||
}
|
||||
|
||||
static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec,
|
||||
apir_buffer_type_host_handle_t * val) {
|
||||
apir_decode(dec, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
|
||||
}
|
||||
|
||||
/* apir_buffer_host_handle_t */
|
||||
|
||||
static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc,
|
||||
const apir_buffer_host_handle_t * val) {
|
||||
apir_encode(enc, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));
|
||||
}
|
||||
|
||||
static inline void apir_decode_apir_buffer_host_handle_t(apir_decoder * dec, apir_buffer_host_handle_t * val) {
|
||||
apir_decode(dec, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));
|
||||
}
|
||||
|
||||
/* uintptr_t */
|
||||
|
||||
static inline void apir_encode_uintptr_t(apir_encoder * enc, const uintptr_t * val) {
|
||||
apir_encode(enc, sizeof(*val), val, sizeof(*val));
|
||||
}
|
||||
|
||||
static inline void apir_decode_uintptr_t(apir_decoder * dec, uintptr_t * val) {
|
||||
apir_decode(dec, sizeof(*val), val, sizeof(*val));
|
||||
}
|
||||
211
ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h
Normal file
211
ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h
Normal file
@@ -0,0 +1,211 @@
|
||||
#include "ggml-impl.h"
|
||||
#include "apir_cs.h"
|
||||
#include "apir_cs_rpc.h"
|
||||
|
||||
// ggml_buffer_to_apir_host_handle(ggml_backend_buffer_t buffer);
|
||||
|
||||
static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc,
|
||||
const apir_buffer_host_handle_t * handle);
|
||||
|
||||
static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec);
|
||||
|
||||
/* apir_rpc_tensor */
|
||||
|
||||
static inline void apir_encode_rcp_tensor(apir_encoder * enc, const apir_rpc_tensor * apir_rpc_tensor) {
|
||||
size_t apir_rpc_tensor_size = sizeof(*apir_rpc_tensor);
|
||||
apir_encode(enc, apir_rpc_tensor_size, apir_rpc_tensor, apir_rpc_tensor_size);
|
||||
}
|
||||
|
||||
static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_inplace(apir_decoder * dec) {
|
||||
size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor);
|
||||
|
||||
return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size);
|
||||
}
|
||||
|
||||
static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec,
|
||||
uint32_t n_tensors) {
|
||||
size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor) * n_tensors;
|
||||
|
||||
return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size);
|
||||
}
|
||||
|
||||
/* ggml_tensor */
|
||||
|
||||
static inline void apir_encode_ggml_tensor(apir_encoder * enc, const ggml_tensor * tensor) {
|
||||
apir_rpc_tensor serialized = apir_serialize_tensor(tensor);
|
||||
|
||||
apir_encode_rcp_tensor(enc, &serialized);
|
||||
}
|
||||
|
||||
static inline const ggml_tensor * apir_decode_ggml_tensor(apir_decoder * dec) {
|
||||
const apir_rpc_tensor * apir_rpc_tensor = apir_decode_apir_rpc_tensor_inplace(dec);
|
||||
ggml_init_params params{
|
||||
/*.mem_size =*/ ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ggml_context * ctx = ggml_init(params);
|
||||
|
||||
const ggml_tensor * tensor = apir_deserialize_tensor(ctx, apir_rpc_tensor);
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
/* *** ggml_backend_buffer_type_t *** */
|
||||
|
||||
// ggml_backend_buffer_type_t is a POINTER (to a struct).
|
||||
// Only the host pointer is shared between the host and guest.
|
||||
// The guest stores it in `buft->context`.
|
||||
// The host simply writes the pointer address in the buffer variable.
|
||||
|
||||
static inline void apir_encode_ggml_buffer_type(apir_encoder * enc, ggml_backend_buffer_type_t buft) {
|
||||
apir_buffer_type_host_handle_t handle = ggml_buffer_type_to_apir_handle(buft);
|
||||
apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle));
|
||||
}
|
||||
|
||||
static inline ggml_backend_buffer_type_t apir_decode_ggml_buffer_type(apir_decoder * dec) {
|
||||
apir_buffer_type_host_handle_t handle;
|
||||
|
||||
apir_decoder_read(dec, sizeof(handle), &handle, sizeof(handle));
|
||||
|
||||
return (ggml_backend_buffer_type_t) handle;
|
||||
}
|
||||
|
||||
static inline apir_buffer_type_host_handle_t apir_decode_apir_buffer_type_host_handle(apir_decoder * dec) {
|
||||
apir_buffer_type_host_handle_t handle;
|
||||
|
||||
apir_decoder_read(dec, sizeof(handle), &handle, sizeof(handle));
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
/* *** ggml_backend_type_t *** */
|
||||
|
||||
// ggml_backend_buffer_t is a POINTER.
|
||||
// same logic as for ggml_backend_buffer_type_t
|
||||
|
||||
static inline void apir_encode_ggml_buffer(apir_encoder * enc, const ggml_backend_buffer_t buffer) {
|
||||
apir_buffer_host_handle_t handle = BUFFER_TO_HOST_HANDLE(buffer);
|
||||
apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle));
|
||||
}
|
||||
|
||||
static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec) {
|
||||
ggml_backend_buffer_t buffer;
|
||||
size_t buffer_ptr_size = sizeof(buffer);
|
||||
|
||||
apir_decoder_read(dec, buffer_ptr_size, &buffer, buffer_ptr_size);
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/* enum ggml_status */
|
||||
|
||||
static inline void apir_encode_ggml_status(apir_encoder * enc, const ggml_status * status) {
|
||||
apir_encoder_write(enc, sizeof(*status), status, sizeof(*status));
|
||||
}
|
||||
|
||||
static inline void apir_decode_ggml_status(apir_decoder * dec, ggml_status * status) {
|
||||
apir_decoder_read(dec, sizeof(*status), status, sizeof(*status));
|
||||
}
|
||||
|
||||
/* virtgpu_shmem */
|
||||
|
||||
static inline void apir_encode_virtgpu_shmem_res_id(apir_encoder * enc, uint32_t shmem_res_id) {
|
||||
apir_encode_uint32_t(enc, &shmem_res_id);
|
||||
}
|
||||
|
||||
static inline void apir_decode_virtgpu_shmem_res_id(apir_decoder * dec, uint32_t * shmem_res_id) {
|
||||
apir_decode_uint32_t(dec, shmem_res_id);
|
||||
}
|
||||
|
||||
/* ggml_cgraph */
|
||||
|
||||
static inline size_t apir_serialize_ggml_cgraph(ggml_cgraph * cgraph, std::vector<uint8_t> & cgraph_data) {
|
||||
apir_serialize_graph(cgraph, cgraph_data);
|
||||
|
||||
return cgraph_data.size();
|
||||
}
|
||||
|
||||
static inline void apir_encode_cgraph_data(apir_encoder * enc, std::vector<uint8_t> & cgraph_data) {
|
||||
size_t cgraph_size = cgraph_data.size();
|
||||
|
||||
apir_encode(enc, cgraph_size, cgraph_data.data(), cgraph_size);
|
||||
}
|
||||
|
||||
static inline ggml_cgraph * apir_decode_ggml_cgraph(apir_decoder * dec, size_t cgraph_size) {
|
||||
GGML_UNUSED(cgraph_size);
|
||||
|
||||
uint32_t n_nodes;
|
||||
apir_decode_uint32_t(dec, &n_nodes);
|
||||
const uint64_t * nodes = apir_decode_uint64_t_array_inplace(dec, n_nodes);
|
||||
|
||||
uint32_t n_tensors;
|
||||
apir_decode_uint32_t(dec, &n_tensors);
|
||||
const apir_rpc_tensor * tensors = apir_decode_apir_rpc_tensor_array_inplace(dec, n_tensors);
|
||||
|
||||
return apir_deserialize_graph(n_nodes, n_tensors, tensors, nodes);
|
||||
}
|
||||
|
||||
static inline void apir_encode_ggml_buffer_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle) {
|
||||
apir_encoder_write(enc, sizeof(*handle), &handle, sizeof(*handle));
|
||||
}
|
||||
|
||||
static inline void apir_encode_ggml_tensor_inline(apir_encoder * enc, const ggml_tensor * tensor) {
|
||||
size_t tensor_size = sizeof(*tensor);
|
||||
|
||||
if (tensor->extra) {
|
||||
GGML_ABORT("Cannot pass tensors with extra");
|
||||
}
|
||||
|
||||
if (tensor->src[0] && tensor->buffer) {
|
||||
static int first = 1;
|
||||
if (first) {
|
||||
GGML_LOG_WARN("Cannot pass tensors with src and buffer\n");
|
||||
first = 0;
|
||||
}
|
||||
}
|
||||
|
||||
apir_encoder_write(enc, tensor_size, tensor, tensor_size);
|
||||
|
||||
// tensor->data is a pointer inside the device buffer. No need to touch it
|
||||
// tensor->buffer is a pointer to a buffer. Encoding the buffer handle in sequence.
|
||||
// (could also make a copy of the tensor, and update locally.)
|
||||
|
||||
if (tensor->buffer) {
|
||||
apir_buffer_host_handle_t buffer_handle = ggml_buffer_to_apir_handle(tensor->buffer);
|
||||
apir_encode_ggml_buffer_handle(enc, &buffer_handle);
|
||||
}
|
||||
|
||||
if (tensor->view_src) {
|
||||
apir_encoder_write(enc, tensor_size, tensor->view_src, tensor_size);
|
||||
}
|
||||
|
||||
for (int i = 0; tensor->src[i]; i++) {
|
||||
const ggml_tensor * tensor_src = tensor->src[i];
|
||||
apir_encoder_write(enc, tensor_size, tensor_src, tensor_size);
|
||||
}
|
||||
}
|
||||
|
||||
static inline const ggml_tensor * apir_decode_ggml_tensor_inplace(apir_decoder * dec) {
|
||||
// it safe to remove the `const` qualifier here, we *do* want to
|
||||
// modify the shared memory data to fix the `src` pointers.
|
||||
ggml_tensor * tensor = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor));
|
||||
|
||||
// tensor->data is a pointer inside the device buffer. No need to touch it
|
||||
// tensor->buffer is a pointer to a buffer. Decode the buffer handle encoded in sequence.
|
||||
if (tensor->buffer) {
|
||||
tensor->buffer = apir_decode_ggml_buffer(dec);
|
||||
}
|
||||
|
||||
if (tensor->view_src) {
|
||||
ggml_tensor * tensor_view_src = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor));
|
||||
tensor->view_src = tensor_view_src;
|
||||
}
|
||||
|
||||
for (int i = 0; tensor->src[i]; i++) {
|
||||
ggml_tensor * tensor_src = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor));
|
||||
tensor->src[i] = tensor_src; // overwrite op->src[i] pointer with the actual location of the src tensor
|
||||
}
|
||||
|
||||
return tensor;
|
||||
}
|
||||
54
ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h
Normal file
54
ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h
Normal file
@@ -0,0 +1,54 @@
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
|
||||
// ggml_tensor is serialized into apir_rpc_tensor
|
||||
struct apir_rpc_tensor {
|
||||
uint64_t id;
|
||||
uint32_t type;
|
||||
uint64_t buffer;
|
||||
uint32_t ne[GGML_MAX_DIMS];
|
||||
uint32_t nb[GGML_MAX_DIMS];
|
||||
uint32_t op;
|
||||
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
int32_t flags;
|
||||
uint64_t src[GGML_MAX_SRC];
|
||||
uint64_t view_src;
|
||||
uint64_t view_offs;
|
||||
uint64_t data;
|
||||
char name[GGML_MAX_NAME];
|
||||
|
||||
char padding[4];
|
||||
};
|
||||
|
||||
/* frontend */
|
||||
|
||||
apir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor);
|
||||
|
||||
void apir_serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output);
|
||||
|
||||
/* backend */
|
||||
|
||||
void apir_track_backend_buffer(ggml_backend_buffer_t buffer);
|
||||
bool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer);
|
||||
std::unordered_set<ggml_backend_buffer_t> apir_get_track_backend_buffers();
|
||||
|
||||
void apir_add_tensor(ggml_tensor * tensor,
|
||||
std::vector<apir_rpc_tensor> & tensors,
|
||||
std::unordered_set<ggml_tensor *> & visited);
|
||||
|
||||
ggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor);
|
||||
|
||||
ggml_tensor * apir_create_node(uint64_t id,
|
||||
ggml_context * ctx,
|
||||
const std::unordered_map<uint64_t, const apir_rpc_tensor *> & tensor_ptrs,
|
||||
std::unordered_map<uint64_t, ggml_tensor *> & tensor_map);
|
||||
|
||||
ggml_cgraph * apir_deserialize_graph(uint32_t n_nodes,
|
||||
uint32_t n_tensors,
|
||||
const apir_rpc_tensor * tensors,
|
||||
const uint64_t * nodes);
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user