mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-03-12 14:43:22 +02:00
Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e22cd0aa15 | ||
|
|
96cfc4992c | ||
|
|
ed0007aa32 | ||
|
|
344ee2a38a | ||
|
|
d6e1556499 | ||
|
|
f76565db92 | ||
|
|
43e1cbd6c1 | ||
|
|
107d599952 | ||
|
|
e8bbc736cb | ||
|
|
b518195101 | ||
|
|
e2763a6723 | ||
|
|
0beb8db3a0 | ||
|
|
b2f460bd3c | ||
|
|
5f4cdac385 | ||
|
|
ae87863dc1 | ||
|
|
97c64fbdbd | ||
|
|
d417bc43dd | ||
|
|
35bee031e1 | ||
|
|
451ef08432 | ||
|
|
9b24886f78 | ||
|
|
62b8143ad2 | ||
|
|
d088d5b74f | ||
|
|
cd18a50ea5 | ||
|
|
a976ff081b | ||
|
|
a95047979a | ||
|
|
b283f6d5b3 | ||
|
|
ff52ee964d | ||
|
|
213c4a0b81 |
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@@ -93,7 +93,7 @@ jobs:
|
||||
id: cmake_test
|
||||
run: |
|
||||
cd build
|
||||
ctest -L main --verbose --timeout 900
|
||||
ctest -L main -E "test-llama-archs" --verbose --timeout 900
|
||||
|
||||
macOS-latest-cmake-x64:
|
||||
runs-on: macos-15-intel
|
||||
|
||||
@@ -39,6 +39,7 @@ Before submitting your PR:
|
||||
- For intricate features, consider opening a feature request first to discuss and align expectations
|
||||
- When adding support for a new model or feature, focus on **CPU support only** in the initial PR unless you have a good reason not to. Add support for other backends like CUDA in follow-up PRs
|
||||
- Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly
|
||||
- If you are a new contributor, limit your open PRs to 1.
|
||||
|
||||
After submitting your PR:
|
||||
- Expect requests for modifications to ensure the code meets llama.cpp's standards for quality and long-term maintainability
|
||||
|
||||
@@ -259,6 +259,8 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [llama-swap](https://github.com/mostlygeek/llama-swap) - transparent proxy that adds automatic model switching with llama-server
|
||||
- [Kalavai](https://github.com/kalavai-net/kalavai-client) - Crowdsource end to end LLM deployment at any scale
|
||||
- [llmaz](https://github.com/InftyAI/llmaz) - ☸️ Easy, advanced inference platform for large language models on Kubernetes.
|
||||
- [LLMKube](https://github.com/defilantech/llmkube) - Kubernetes operator for llama.cpp with multi-GPU and Apple Silicon Metal
|
||||
support"
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
@@ -2666,7 +2666,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.out_file = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE}));
|
||||
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE, LLAMA_EXAMPLE_RESULTS}));
|
||||
add_opt(common_arg(
|
||||
{"-ofreq", "--output-frequency"}, "N",
|
||||
string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq),
|
||||
@@ -3607,6 +3607,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
}
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"--check"},
|
||||
string_format("check rather than generate results (default: %s)", params.check ? "true" : "false"),
|
||||
[](common_params & params) {
|
||||
params.check = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_RESULTS}));
|
||||
add_opt(common_arg(
|
||||
{"--save-logits"},
|
||||
string_format("save final logits to files for verification (default: %s)", params.save_logits ? "true" : "false"),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
@@ -51,13 +52,15 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
||||
bool has_tools =
|
||||
autoparser.tools.format.mode != tool_format::NONE && inputs.tools.is_array() && !inputs.tools.empty();
|
||||
std::string trigger_marker = !autoparser.tools.format.section_start.empty() ? autoparser.tools.format.section_start :
|
||||
autoparser.tools.format.per_call_start;
|
||||
bool include_grammar =
|
||||
has_tools && ((inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO && !trigger_marker.empty()) ||
|
||||
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED);
|
||||
autoparser.tools.format.per_call_start;
|
||||
|
||||
bool has_response_format = !inputs.json_schema.empty() && inputs.json_schema.is_object();
|
||||
bool include_grammar = has_response_format || (has_tools &&
|
||||
((inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO && !trigger_marker.empty()) ||
|
||||
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
data.grammar_lazy = !has_response_format && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
@@ -68,7 +71,7 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
||||
});
|
||||
|
||||
// Set grammar triggers based on tool section markers (fall back to per-call markers)
|
||||
if (data.grammar_lazy) { // only do triggers on lazy grammar
|
||||
if (data.grammar_lazy) {
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, trigger_marker }
|
||||
};
|
||||
@@ -104,8 +107,11 @@ common_peg_arena autoparser::build_parser(const templates_params & inputs) const
|
||||
bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
|
||||
|
||||
if (has_response_format) {
|
||||
return ctx.reasoning_parser + p.space() +
|
||||
p.content(p.schema(p.json(), "response-format", inputs.json_schema)) + p.end();
|
||||
auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
|
||||
return ctx.reasoning_parser + p.space() + p.choice({
|
||||
p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"),
|
||||
response_format
|
||||
}) + p.end();
|
||||
}
|
||||
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
|
||||
|
||||
@@ -162,7 +162,7 @@ diff_split calculate_diff_split(const std::string & left, const std::string & ri
|
||||
right_fully_consumed = true;
|
||||
}
|
||||
|
||||
auto eat_segment = [](std::string & str, segment & seg) -> std::string { return str.append(seg.value); };
|
||||
auto eat_segment = [](std::string str, const segment & seg) -> std::string { return std::move(str) + seg.value; };
|
||||
|
||||
bool can_have_text_suffix = left_end->type == segment_type::TEXT && right_end->type == segment_type::TEXT;
|
||||
bool can_have_text_prefix = right_start->type == segment_type::TEXT && left_start->type == segment_type::TEXT;
|
||||
|
||||
@@ -167,8 +167,8 @@ void tag_based_peg_mapper::from_ast(const common_peg_ast_arena & arena, const co
|
||||
});
|
||||
}
|
||||
|
||||
tagged_parse_result tagged_peg_parser::parse_and_extract(const std::string & input, bool is_partial) const {
|
||||
common_peg_parse_context ctx(input, is_partial);
|
||||
tagged_parse_result tagged_peg_parser::parse_and_extract(const std::string & input, common_peg_parse_flags extra_flags) const {
|
||||
common_peg_parse_context ctx(input, flags | extra_flags);
|
||||
auto parse_result = arena.parse(ctx);
|
||||
|
||||
tag_based_peg_mapper mapper;
|
||||
@@ -179,11 +179,10 @@ tagged_parse_result tagged_peg_parser::parse_and_extract(const std::string & inp
|
||||
|
||||
tagged_parse_result tagged_peg_parser::parse_anywhere_and_extract(const std::string & input) const {
|
||||
if (input.empty()) {
|
||||
return parse_and_extract(input, false);
|
||||
return parse_and_extract(input);
|
||||
}
|
||||
for (size_t i = 0; i < input.size(); i++) {
|
||||
common_peg_parse_context ctx(input, false);
|
||||
ctx.debug = debug;
|
||||
common_peg_parse_context ctx(input, flags);
|
||||
auto parse_result = arena.parse(ctx, i);
|
||||
if (parse_result.success() || i == input.size() - 1) {
|
||||
tag_based_peg_mapper mapper;
|
||||
@@ -477,6 +476,74 @@ common_peg_parser common_chat_peg_builder::standard_constructed_tools(
|
||||
return force_tool_calls ? section : optional(section);
|
||||
}
|
||||
|
||||
// Python-style tool calls: name(arg1="value1", arg2=123)
|
||||
// Used only by LFM2 for now, so we don't merge it into autoparser
|
||||
common_peg_parser common_chat_peg_builder::python_style_tool_calls(
|
||||
const nlohmann::json & tools,
|
||||
bool parallel_tool_calls) {
|
||||
if (!tools.is_array() || tools.empty()) {
|
||||
return eps();
|
||||
}
|
||||
|
||||
auto tool_choices = choice();
|
||||
|
||||
for (const auto & tool_def : tools) {
|
||||
if (!tool_def.contains("function")) {
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool_def.at("function");
|
||||
std::string name = function.at("name");
|
||||
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
|
||||
|
||||
auto args = eps();
|
||||
if (params.contains("properties") && !params["properties"].empty()) {
|
||||
auto arg_choice = choice();
|
||||
for (const auto & el : params["properties"].items()) {
|
||||
const std::string & prop_name = el.key();
|
||||
const auto & prop_def = el.value();
|
||||
bool is_string_type = (prop_def.contains("type") && prop_def["type"] == "string");
|
||||
|
||||
auto arg_name_parser = literal(prop_name);
|
||||
|
||||
common_peg_parser arg_value_parser = eps();
|
||||
auto string_value_parser = choice({
|
||||
literal("\"") + tool_arg_string_value(json_string_content()) + literal("\""),
|
||||
literal("'") + tool_arg_string_value(json_string_content()) + literal("'")
|
||||
});
|
||||
|
||||
if (is_string_type) {
|
||||
arg_value_parser = string_value_parser;
|
||||
} else {
|
||||
arg_value_parser = tool_arg_value(python_value());
|
||||
}
|
||||
|
||||
// Full argument: name="value" or name=value
|
||||
auto arg_rule = tool_arg(
|
||||
tool_arg_open(eps()) +
|
||||
tool_arg_name(arg_name_parser) +
|
||||
literal("=") +
|
||||
arg_value_parser +
|
||||
tool_arg_close(eps())
|
||||
);
|
||||
arg_choice |= arg_rule;
|
||||
}
|
||||
|
||||
args = arg_choice + zero_or_more("," + space() + arg_choice);
|
||||
}
|
||||
|
||||
auto tool_parser = tool(tool_open(tool_name(literal(name)) + literal("(")) +
|
||||
space() + tool_args(args) + space() + tool_close(literal(")"))
|
||||
);
|
||||
|
||||
tool_choices |= rule("tool-" + name, tool_parser);
|
||||
}
|
||||
|
||||
if (parallel_tool_calls) {
|
||||
return "[" + space() + tool_choices + zero_or_more("," + space() + tool_choices) + space() + "]";
|
||||
}
|
||||
return "[" + space() + tool_choices + space() + "]";
|
||||
}
|
||||
|
||||
// Helper: Parse dot notation key into prefix and field name
|
||||
static std::pair<std::string, std::string> parse_key_spec(const std::string & key) {
|
||||
auto dot_pos = key.find('.');
|
||||
|
||||
@@ -112,6 +112,11 @@ class common_chat_peg_builder : public common_peg_parser_builder {
|
||||
bool parallel_tool_calls,
|
||||
bool force_tool_calls);
|
||||
|
||||
// Helper for Python-style function call format: name(arg1="value1", arg2=123)
|
||||
// Used by LFM2 and similar templates
|
||||
common_peg_parser python_style_tool_calls(const nlohmann::json & tools,
|
||||
bool parallel_tool_calls);
|
||||
|
||||
private:
|
||||
// Implementation helpers for standard_json_tools — one per JSON tool call layout mode
|
||||
common_peg_parser build_json_tools_function_is_key(const nlohmann::json & tools,
|
||||
@@ -155,19 +160,19 @@ struct tagged_parse_result {
|
||||
|
||||
struct tagged_peg_parser {
|
||||
common_peg_arena arena;
|
||||
bool debug = false;
|
||||
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_NONE;
|
||||
|
||||
tagged_peg_parser & withDebug() {
|
||||
debug = true;
|
||||
flags |= COMMON_PEG_PARSE_FLAG_DEBUG;
|
||||
return *this;
|
||||
}
|
||||
|
||||
tagged_peg_parser & withoutDebug() {
|
||||
debug = false;
|
||||
flags = flags & ~COMMON_PEG_PARSE_FLAG_DEBUG;
|
||||
return *this;
|
||||
}
|
||||
|
||||
tagged_parse_result parse_and_extract(const std::string & input, bool is_partial = false) const;
|
||||
tagged_parse_result parse_and_extract(const std::string & input, common_peg_parse_flags extra_flags = COMMON_PEG_PARSE_FLAG_NONE) const;
|
||||
tagged_parse_result parse_anywhere_and_extract(const std::string & input) const;
|
||||
};
|
||||
|
||||
|
||||
115
common/chat.cpp
115
common/chat.cpp
@@ -129,7 +129,7 @@ json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tool_call.name},
|
||||
{"arguments", json::parse(tool_call.arguments)},
|
||||
{"arguments", json(tool_call.arguments)},
|
||||
}},
|
||||
};
|
||||
if (!tool_call.id.empty()) {
|
||||
@@ -1274,8 +1274,95 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
||||
return data;
|
||||
}
|
||||
|
||||
// LFM2 format:
|
||||
// - Reasoning: <think>{reasoning}</think> (optional, only if enable_thinking is true)
|
||||
// - Content: text after reasoning (optional)
|
||||
// - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|>
|
||||
// Tool calls can appear multiple times (parallel tool calls)
|
||||
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
"<|tool_list_start|>",
|
||||
"<|tool_list_end|>",
|
||||
"<|tool_call_start|>",
|
||||
"<|tool_call_end|>",
|
||||
"<think>",
|
||||
"</think>",
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
|
||||
|
||||
const std::string TOOL_CALL_START = "<|tool_call_start|>";
|
||||
const std::string TOOL_CALL_END = "<|tool_call_end|>";
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
if (extract_reasoning && inputs.enable_thinking) {
|
||||
reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END);
|
||||
}
|
||||
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return reasoning + p.content(p.rest()) + end;
|
||||
}
|
||||
|
||||
auto tool_calls = p.rule("tool-calls",
|
||||
p.trigger_rule("tool-call", p.literal(TOOL_CALL_START) +
|
||||
p.python_style_tool_calls(inputs.tools, inputs.parallel_tool_calls) +
|
||||
p.literal(TOOL_CALL_END)
|
||||
)
|
||||
);
|
||||
|
||||
auto content = p.content(p.until(TOOL_CALL_START));
|
||||
|
||||
return reasoning + content + tool_calls + end;
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, TOOL_CALL_START }
|
||||
};
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
namespace workaround {
|
||||
|
||||
static void map_developer_role_to_system(json & messages) {
|
||||
for (auto & message : messages) {
|
||||
if (message.contains("role")) {
|
||||
if (message["role"] == "developer") {
|
||||
message["role"] = "system";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// if first message is system and template does not support it, merge it with next message
|
||||
static void system_message_not_supported(json & messages) {
|
||||
if (!messages.empty() && messages.front().at("role") == "system") {
|
||||
@@ -1353,6 +1440,12 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
|
||||
if (src.find("<|channel|>") == std::string::npos) {
|
||||
// map developer to system for all models except for GPT-OSS
|
||||
workaround::map_developer_role_to_system(params.messages);
|
||||
}
|
||||
workaround::func_args_not_string(params.messages);
|
||||
|
||||
if (!tmpl.original_caps().supports_system_role) {
|
||||
workaround::system_message_not_supported(params.messages);
|
||||
}
|
||||
@@ -1420,6 +1513,14 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
return common_chat_params_init_kimi_k2(tmpl, params);
|
||||
}
|
||||
|
||||
// LFM2 - uses <|tool_list_start|>/<|tool_list_end|> markers and <|tool_call_start|>[name(args)]<|tool_call_end|> format
|
||||
// Detection: template has "<|tool_list_start|>" and "<|tool_list_end|>" markers
|
||||
if (src.find("<|tool_list_start|>") != std::string::npos &&
|
||||
src.find("<|tool_list_end|>") != std::string::npos) {
|
||||
LOG_DBG("Using specialized template: LFM2\n");
|
||||
return common_chat_params_init_lfm2(tmpl, params);
|
||||
}
|
||||
|
||||
try {
|
||||
LOG_DBG("Using differential autoparser\n");
|
||||
struct autoparser::autoparser autoparser;
|
||||
@@ -1525,8 +1626,12 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
|
||||
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str());
|
||||
|
||||
common_peg_parse_context ctx(input, is_partial);
|
||||
ctx.debug = params.debug;
|
||||
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT;
|
||||
if (params.debug) {
|
||||
flags |= COMMON_PEG_PARSE_FLAG_DEBUG;
|
||||
}
|
||||
|
||||
common_peg_parse_context ctx(input, flags);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
if (result.fail()) {
|
||||
@@ -1539,7 +1644,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
auto mapper = common_chat_peg_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "\nAST for partial parse (fail):\n%s\n", ctx.ast.dump().c_str());
|
||||
fflush(stderr);
|
||||
}
|
||||
@@ -1555,7 +1660,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
auto mapper = common_chat_peg_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "\nAST for %s parse:\n%s\n", is_partial ? "partial" : "full", ctx.ast.dump().c_str());
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
@@ -104,6 +104,7 @@ enum llama_example {
|
||||
LLAMA_EXAMPLE_DIFFUSION,
|
||||
LLAMA_EXAMPLE_FINETUNE,
|
||||
LLAMA_EXAMPLE_FIT_PARAMS,
|
||||
LLAMA_EXAMPLE_RESULTS,
|
||||
|
||||
LLAMA_EXAMPLE_COUNT,
|
||||
};
|
||||
@@ -456,6 +457,8 @@ struct common_params {
|
||||
|
||||
bool kl_divergence = false; // compute KL divergence
|
||||
|
||||
bool check = false; // check rather than generate results for llama-results
|
||||
|
||||
bool usage = false; // print usage
|
||||
bool completion = false; // print source-able completion script
|
||||
bool use_color = false; // use color to distinguish generations and inputs
|
||||
|
||||
@@ -349,7 +349,7 @@ struct parser_executor {
|
||||
auto pos = start_pos;
|
||||
for (auto i = 0u; i < p.literal.size(); ++i) {
|
||||
if (pos >= ctx.input.size()) {
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
@@ -364,7 +364,7 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_sequence_parser & p) {
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
LOG_DBG("%sSEQ start at %zu '%s' (%zu children)\n", debug_indent().c_str(), start_pos,
|
||||
debug_input_snippet(start_pos).c_str(), p.children.size());
|
||||
}
|
||||
@@ -375,26 +375,19 @@ struct parser_executor {
|
||||
|
||||
for (size_t i = 0; i < p.children.size(); i++) {
|
||||
const auto & child_id = p.children[i];
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sSEQ child %zu: %s\n", debug_indent().c_str(), i, arena.dump(child_id).c_str());
|
||||
}
|
||||
auto result = arena.parse(child_id, ctx, pos);
|
||||
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sSEQ child %zu: %s at %zu->%zu\n", debug_indent().c_str(), i,
|
||||
common_peg_parse_result_type_name(result.type), result.start, result.end);
|
||||
}
|
||||
|
||||
if (result.fail()) {
|
||||
ctx.parse_depth--;
|
||||
if (ctx.is_partial && result.end >= ctx.input.size()) {
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sSEQ -> NEED_MORE (child failed at end)\n", debug_indent().c_str());
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end,
|
||||
std::move(nodes));
|
||||
}
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sSEQ -> FAIL\n", debug_indent().c_str());
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, result.end);
|
||||
@@ -406,7 +399,7 @@ struct parser_executor {
|
||||
|
||||
if (result.need_more_input()) {
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sSEQ -> NEED_MORE\n", debug_indent().c_str());
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes));
|
||||
@@ -416,14 +409,14 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sSEQ -> SUCCESS at %zu->%zu\n", debug_indent().c_str(), start_pos, pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes));
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_choice_parser & p) {
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sCHOICE start at %zu '%s' (%zu options)\n", debug_indent().c_str(), start_pos,
|
||||
debug_input_snippet(start_pos).c_str(), p.children.size());
|
||||
}
|
||||
@@ -432,17 +425,17 @@ struct parser_executor {
|
||||
auto pos = start_pos;
|
||||
for (size_t i = 0; i < p.children.size(); i++) {
|
||||
const auto & child_id = p.children[i];
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sCHOICE option %zu: %s\n", debug_indent().c_str(), i, arena.dump(child_id).c_str());
|
||||
}
|
||||
auto result = arena.parse(child_id, ctx, pos);
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sCHOICE option %zu: %s\n", debug_indent().c_str(), i,
|
||||
common_peg_parse_result_type_name(result.type));
|
||||
}
|
||||
if (!result.fail()) {
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sCHOICE -> %s (option %zu)\n", debug_indent().c_str(),
|
||||
common_peg_parse_result_type_name(result.type), i);
|
||||
}
|
||||
@@ -451,14 +444,14 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sCHOICE -> FAIL (no options matched)\n", debug_indent().c_str());
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_repetition_parser & p) {
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT start at %zu '%s' (min=%d, max=%d)\n", debug_indent().c_str(), start_pos,
|
||||
debug_input_snippet(start_pos).c_str(), p.min_count, p.max_count);
|
||||
}
|
||||
@@ -471,7 +464,7 @@ struct parser_executor {
|
||||
// Try to match up to max_count times (or unlimited if max_count is -1)
|
||||
while (p.max_count == -1 || match_count < p.max_count) {
|
||||
if (pos >= ctx.input.size()) {
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT: at end of input, count=%d\n", debug_indent().c_str(), match_count);
|
||||
}
|
||||
break;
|
||||
@@ -479,7 +472,7 @@ struct parser_executor {
|
||||
|
||||
auto result = arena.parse(p.child, ctx, pos);
|
||||
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT iter %d: %s at %zu->%zu, nodes=%zu\n", debug_indent().c_str(), match_count,
|
||||
common_peg_parse_result_type_name(result.type), result.start, result.end, result.nodes.size());
|
||||
fprintf(stderr, "%sREPEAT CHILD: %s\n", debug_indent().c_str(), arena.dump(p.child).c_str());
|
||||
@@ -488,7 +481,7 @@ struct parser_executor {
|
||||
if (result.success()) {
|
||||
// Prevent infinite loop on empty matches
|
||||
if (result.end == pos) {
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%s REPEAT: empty match, stopping\n", debug_indent().c_str());
|
||||
}
|
||||
break;
|
||||
@@ -509,7 +502,7 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT -> NEED_MORE (count=%d, nodes=%zu)\n", debug_indent().c_str(),
|
||||
match_count, nodes.size());
|
||||
}
|
||||
@@ -517,7 +510,7 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
// Child failed - stop trying
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT: child failed, stopping\n", debug_indent().c_str());
|
||||
}
|
||||
break;
|
||||
@@ -526,14 +519,14 @@ struct parser_executor {
|
||||
// Check if we got enough matches
|
||||
if (p.min_count > 0 && match_count < p.min_count) {
|
||||
ctx.parse_depth--;
|
||||
if (pos >= ctx.input.size() && ctx.is_partial) {
|
||||
if (ctx.debug) {
|
||||
if (pos >= ctx.input.size() && ctx.is_lenient()) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT -> NEED_MORE (not enough matches: %d < %d)\n", debug_indent().c_str(),
|
||||
match_count, p.min_count);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos, std::move(nodes));
|
||||
}
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT -> FAIL (not enough matches: %d < %d)\n", debug_indent().c_str(), match_count,
|
||||
p.min_count);
|
||||
}
|
||||
@@ -541,7 +534,7 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT -> SUCCESS (count=%d, nodes=%zu)\n", debug_indent().c_str(), match_count,
|
||||
nodes.size());
|
||||
}
|
||||
@@ -576,7 +569,7 @@ struct parser_executor {
|
||||
auto result = common_parse_utf8_codepoint(ctx.input, start_pos);
|
||||
|
||||
if (result.status == utf8_parse_result::INCOMPLETE) {
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos);
|
||||
@@ -615,7 +608,7 @@ struct parser_executor {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
|
||||
}
|
||||
// Not enough matches yet
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
@@ -656,7 +649,7 @@ struct parser_executor {
|
||||
|
||||
// Check if we got enough matches
|
||||
if (match_count < p.min_count) {
|
||||
if (pos >= ctx.input.size() && ctx.is_partial) {
|
||||
if (pos >= ctx.input.size() && ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
|
||||
@@ -668,7 +661,7 @@ struct parser_executor {
|
||||
static common_peg_parse_result handle_escape_sequence(common_peg_parse_context & ctx, size_t start, size_t & pos) {
|
||||
++pos; // consume '\'
|
||||
if (pos >= ctx.input.size()) {
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos);
|
||||
@@ -698,7 +691,7 @@ struct parser_executor {
|
||||
++pos; // consume 'u'
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
if (pos >= ctx.input.size()) {
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos);
|
||||
@@ -732,7 +725,7 @@ struct parser_executor {
|
||||
auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos);
|
||||
|
||||
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
@@ -747,7 +740,7 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
// Reached end without finding closing quote
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
@@ -774,7 +767,7 @@ struct parser_executor {
|
||||
auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos);
|
||||
|
||||
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
@@ -789,7 +782,7 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
// Reached end without finding closing quote
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
@@ -807,7 +800,7 @@ struct parser_executor {
|
||||
|
||||
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
|
||||
// Incomplete UTF-8 sequence
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
// Input is complete but UTF-8 is incomplete = malformed
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
@@ -837,7 +830,7 @@ struct parser_executor {
|
||||
last_valid_pos = pos;
|
||||
}
|
||||
|
||||
if (last_valid_pos == ctx.input.size() && ctx.is_partial) {
|
||||
if (last_valid_pos == ctx.input.size() && ctx.is_lenient()) {
|
||||
// Reached the end of a partial stream, there might still be more input that we need to consume.
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, last_valid_pos);
|
||||
}
|
||||
@@ -876,7 +869,7 @@ struct parser_executor {
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_tag_parser & p) {
|
||||
// Parse the child
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sTAG: %s\n", debug_indent().c_str(), p.tag.c_str());
|
||||
}
|
||||
auto result = arena.parse(p.child, ctx, start_pos);
|
||||
|
||||
@@ -139,22 +139,43 @@ struct common_peg_parse_result {
|
||||
bool success() const { return type == COMMON_PEG_PARSE_RESULT_SUCCESS; }
|
||||
};
|
||||
|
||||
enum common_peg_parse_flags {
|
||||
COMMON_PEG_PARSE_FLAG_NONE = 0,
|
||||
COMMON_PEG_PARSE_FLAG_LENIENT = 1 << 0,
|
||||
COMMON_PEG_PARSE_FLAG_DEBUG = 1 << 1,
|
||||
};
|
||||
|
||||
inline common_peg_parse_flags operator|(common_peg_parse_flags a, common_peg_parse_flags b) {
|
||||
return static_cast<common_peg_parse_flags>(int(a) | int(b));
|
||||
}
|
||||
|
||||
inline common_peg_parse_flags & operator|=(common_peg_parse_flags & a, common_peg_parse_flags b) {
|
||||
return a = a | b;
|
||||
}
|
||||
|
||||
inline common_peg_parse_flags operator&(common_peg_parse_flags a, common_peg_parse_flags b) {
|
||||
return static_cast<common_peg_parse_flags>(int(a) & int(b));
|
||||
}
|
||||
|
||||
inline common_peg_parse_flags operator~(common_peg_parse_flags a) {
|
||||
return static_cast<common_peg_parse_flags>(~int(a));
|
||||
}
|
||||
|
||||
struct common_peg_parse_context {
|
||||
std::string input;
|
||||
bool is_partial;
|
||||
bool debug = false; // Enable debug output for parser tracing
|
||||
common_peg_parse_flags flags;
|
||||
common_peg_ast_arena ast;
|
||||
|
||||
int parse_depth;
|
||||
|
||||
common_peg_parse_context()
|
||||
: is_partial(false), parse_depth(0) {}
|
||||
common_peg_parse_context(common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_NONE)
|
||||
: flags(flags), parse_depth(0) {}
|
||||
|
||||
common_peg_parse_context(const std::string & input)
|
||||
: input(input), is_partial(false), parse_depth(0) {}
|
||||
common_peg_parse_context(const std::string & input, common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_NONE)
|
||||
: input(input), flags(flags), parse_depth(0) {}
|
||||
|
||||
common_peg_parse_context(const std::string & input, bool is_partial)
|
||||
: input(input), is_partial(is_partial), parse_depth(0) {}
|
||||
bool is_lenient() const { return flags & COMMON_PEG_PARSE_FLAG_LENIENT; }
|
||||
bool is_debug() const { return flags & COMMON_PEG_PARSE_FLAG_DEBUG; }
|
||||
};
|
||||
|
||||
class common_peg_arena;
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
- [Linux](#linux)
|
||||
- [Windows](#windows)
|
||||
- [Environment Variable](#environment-variable)
|
||||
- [Design Rule](#design-rule)
|
||||
- [Known Issue](#known-issues)
|
||||
- [Q&A](#qa)
|
||||
- [TODO](#todo)
|
||||
@@ -41,6 +42,9 @@ The following releases are verified and recommended:
|
||||
|
||||
## News
|
||||
|
||||
- 2026.03
|
||||
- Support Flash-Attention: less memory usage, performance impact depends on LLM.
|
||||
|
||||
- 2026.02
|
||||
- Remove support for Nvidia & AMD GPU, because the oneAPI plugin for Nvidia & AMD GPU is unavailable: download/installation channels are out of work. User can't build up the software for Nvidia & AMD GPU.
|
||||
|
||||
@@ -685,18 +689,45 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
| Name | Value | Function |
|
||||
|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
|
||||
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
|
||||
| GGML_SYCL_ENABLE_FLASH_ATTN | 1 (default) or 0| Enable Flash-Attention. It can reduce memory usage. The performance impact depends on the LLM.|
|
||||
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features for Intel GPUs. (Recommended to 1 for intel devices older than Gen 10) |
|
||||
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because SYCL Graph is still on development, no better performance. |
|
||||
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
|
||||
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
|
||||
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.|
|
||||
|
||||
## Design Rule
|
||||
|
||||
- Open to all contributors.
|
||||
|
||||
- All code change should be useful to user:
|
||||
- Fix bug.
|
||||
- Add new function.
|
||||
- Improve the performance/usage.
|
||||
- Make code be easy to maintain.
|
||||
- ...
|
||||
|
||||
- Don't accept the codes of following cases:
|
||||
- Break legacy function.
|
||||
- Reduce the performance of legacy case in default.
|
||||
- Not completed work/the functionality cannot be demonstrated.
|
||||
|
||||
- Encourage to use environment variable to control features to be opened/closed.
|
||||
- User can evaluate the feature without rebuild the code.
|
||||
- Recommend the best features to user by setting them be opened as default.
|
||||
|
||||
- Design the code based on the published official releases of oneAPI packages: compiler, library, driver, OS kernel.
|
||||
|
||||
- Developers need to maintain the code they submit.
|
||||
|
||||
## Known Issues
|
||||
|
||||
- `Split-mode:[row]` is not supported.
|
||||
|
||||
- Missed the AOT (Ahead-of-Time) in buiding.
|
||||
- Good: build quickly, smaller size of binary file.
|
||||
- Bad: The startup is slow (JIT) in first time, but subsequent performance is unaffected.
|
||||
|
||||
## Q&A
|
||||
|
||||
- Error: `error while loading shared libraries: libsycl.so: cannot open shared object file: No such file or directory`.
|
||||
|
||||
19
docs/ops.md
19
docs/ops.md
@@ -37,16 +37,17 @@ Legend:
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
@@ -54,7 +55,7 @@ Legend:
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
@@ -90,9 +91,9 @@ Legend:
|
||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
@@ -100,7 +101,7 @@ Legend:
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
@@ -116,5 +117,5 @@ Legend:
|
||||
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
|
||||
23688
docs/ops/SYCL.csv
23688
docs/ops/SYCL.csv
File diff suppressed because it is too large
Load Diff
9167
docs/ops/Vulkan.csv
9167
docs/ops/Vulkan.csv
File diff suppressed because it is too large
Load Diff
@@ -205,7 +205,14 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
|
||||
|
||||
int64_t total_vram = 0;
|
||||
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
|
||||
total_vram += prop.totalGlobalMem;
|
||||
}
|
||||
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices (Total VRAM: %zu MiB):\n",
|
||||
__func__, info.device_count, (size_t)(total_vram / (1024 * 1024)));
|
||||
total_vram = 0;
|
||||
|
||||
std::vector<std::pair<int, std::string>> turing_devices_without_mma;
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
@@ -243,6 +250,12 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
#else
|
||||
info.devices[id].supports_cooperative_launch = false;
|
||||
#endif // !(GGML_USE_MUSA)
|
||||
|
||||
// cudaMemGetInfo returns info for the current device
|
||||
size_t free_mem;
|
||||
CUDA_CHECK(cudaSetDevice(id));
|
||||
CUDA_CHECK(cudaMemGetInfo(&free_mem, NULL));
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlock;
|
||||
|
||||
@@ -257,22 +270,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
info.devices[id].cc += prop.minor * 0x10;
|
||||
}
|
||||
}
|
||||
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
|
||||
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d, VRAM: %zu MiB (%zu MiB free)\n",
|
||||
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
|
||||
device_vmm ? "yes" : "no", prop.warpSize);
|
||||
device_vmm ? "yes" : "no", prop.warpSize,
|
||||
(size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024));
|
||||
#elif defined(GGML_USE_MUSA)
|
||||
// FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
|
||||
info.devices[id].warp_size = 32;
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
|
||||
info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
|
||||
info.devices[id].cc += prop.minor * 0x10;
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB (%zu MiB free)\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
|
||||
(size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024));
|
||||
#else
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB (%zu MiB free)\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
|
||||
(size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024));
|
||||
std::string device_name(prop.name);
|
||||
if (device_name == "NVIDIA GeForce MX450") {
|
||||
turing_devices_without_mma.push_back({ id, device_name });
|
||||
@@ -4976,9 +4992,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
return true;
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
//TODO: enable once MUSA compiler is solved https://github.com/ggml-org/llama.cpp/pull/19504#issuecomment-4018634327
|
||||
#ifdef GGML_USE_MUSA
|
||||
return false;
|
||||
#else
|
||||
return true;
|
||||
#endif // GGML_USE_MUSA
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
|
||||
@@ -1717,12 +1717,29 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_met
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
|
||||
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
|
||||
|
||||
const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
|
||||
|
||||
if (mode == GGML_SCALE_MODE_BILINEAR) {
|
||||
snprintf(base, 256, "kernel_upscale_bilinear_%s", ggml_type_name(op->src[0]->type));
|
||||
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
|
||||
snprintf(base, 256, "kernel_upscale_bicubic_%s", ggml_type_name(op->src[0]->type));
|
||||
} else {
|
||||
snprintf(base, 256, "kernel_upscale_nearest_%s", ggml_type_name(op->src[0]->type));
|
||||
}
|
||||
snprintf(name, 256, "%s_aa=%d", base, antialias);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_bool(cv, antialias, FC_UPSCALE + 0);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
return res;
|
||||
|
||||
@@ -1108,7 +1108,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
op->type == GGML_TYPE_F32 &&
|
||||
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_POOL_1D:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_POOL_2D:
|
||||
|
||||
@@ -83,6 +83,7 @@
|
||||
#define FC_UNARY 1200
|
||||
#define FC_BIN 1300
|
||||
#define FC_SUM_ROWS 1400
|
||||
#define FC_UPSCALE 1500
|
||||
|
||||
// op-specific constants
|
||||
#define OP_FLASH_ATTN_EXT_NQPSG 8
|
||||
@@ -890,6 +891,7 @@ typedef struct {
|
||||
float sf1;
|
||||
float sf2;
|
||||
float sf3;
|
||||
float poffs;
|
||||
} ggml_metal_kargs_upscale;
|
||||
|
||||
typedef struct {
|
||||
|
||||
@@ -1963,6 +1963,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
(
|
||||
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
|
||||
op->src[0]->type == GGML_TYPE_F16 ||
|
||||
op->src[0]->type == GGML_TYPE_BF16 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_0 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_1 ||
|
||||
op->src[0]->type == GGML_TYPE_Q5_0 ||
|
||||
@@ -1977,6 +1978,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
op->src[0]->type == GGML_TYPE_Q4_K ||
|
||||
op->src[0]->type == GGML_TYPE_Q5_K ||
|
||||
op->src[0]->type == GGML_TYPE_Q6_K ||
|
||||
op->src[0]->type == GGML_TYPE_Q2_K ||
|
||||
op->src[0]->type == GGML_TYPE_Q3_K ||
|
||||
false) && (ne11 >= 4 && ne11 <= 8)
|
||||
)
|
||||
)
|
||||
@@ -3729,32 +3732,43 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const float sf0 = (float)ne0/op->src[0]->ne[0];
|
||||
const float sf1 = (float)ne1/op->src[0]->ne[1];
|
||||
const float sf2 = (float)ne2/op->src[0]->ne[2];
|
||||
const float sf3 = (float)ne3/op->src[0]->ne[3];
|
||||
float sf0 = (float)ne0/op->src[0]->ne[0];
|
||||
float sf1 = (float)ne1/op->src[0]->ne[1];
|
||||
float sf2 = (float)ne2/op->src[0]->ne[2];
|
||||
float sf3 = (float)ne3/op->src[0]->ne[3];
|
||||
|
||||
const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
|
||||
|
||||
float poffs = 0.5f;
|
||||
|
||||
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
poffs = 0.0f;
|
||||
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
|
||||
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
|
||||
}
|
||||
|
||||
ggml_metal_kargs_upscale args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.sf0 =*/ sf0,
|
||||
/*.sf1 =*/ sf1,
|
||||
/*.sf2 =*/ sf2,
|
||||
/*.sf3 =*/ sf3
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.sf0 =*/ sf0,
|
||||
/*.sf1 =*/ sf1,
|
||||
/*.sf2 =*/ sf2,
|
||||
/*.sf3 =*/ sf3,
|
||||
/*.poffs =*/ poffs,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
|
||||
|
||||
@@ -3481,6 +3481,13 @@ template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4
|
||||
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
|
||||
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, bfloat4, 4, dequantize_bf16_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, bfloat4, 4, dequantize_bf16_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, bfloat4, 4, dequantize_bf16_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
@@ -3531,6 +3538,16 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
|
||||
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q2_K, 256, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q2_K, 256, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q2_K, 256, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q2_K, 256, dequantize_q2_K>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q3_K, 256, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q3_K, 256, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q3_K, 256, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q3_K, 256, dequantize_q3_K>;
|
||||
|
||||
template<typename T0, typename T1, short NR0, typename args_t>
|
||||
void kernel_mul_mv_t_t_impl(
|
||||
args_t args,
|
||||
@@ -4530,7 +4547,9 @@ kernel void kernel_conv_transpose_2d<half>(
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
kernel void kernel_upscale_f32(
|
||||
constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
|
||||
|
||||
kernel void kernel_upscale_nearest_f32(
|
||||
constant ggml_metal_kargs_upscale & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
@@ -4556,6 +4575,156 @@ kernel void kernel_upscale_f32(
|
||||
}
|
||||
}
|
||||
|
||||
static inline float bilinear_tri(float x) {
|
||||
return MAX(0.0f, 1.0f - fabs(x));
|
||||
}
|
||||
|
||||
kernel void kernel_upscale_bilinear_f32(
|
||||
constant ggml_metal_kargs_upscale & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
|
||||
const int64_t i03 = i3 / args.sf3;
|
||||
const int64_t i02 = i2 / args.sf2;
|
||||
|
||||
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
|
||||
const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
|
||||
const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
|
||||
const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
|
||||
|
||||
src0 += i03*args.nb03 + i02*args.nb02;
|
||||
|
||||
device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
||||
|
||||
if (FC_upscale_aa) {
|
||||
const float support0 = MAX(1.0f, 1.0f / args.sf0);
|
||||
const float invscale0 = 1.0f / support0;
|
||||
const float support1 = MAX(1.0f, 1.0f / args.sf1);
|
||||
const float invscale1 = 1.0f / support1;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
||||
|
||||
int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
|
||||
int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
|
||||
|
||||
int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
|
||||
int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));
|
||||
|
||||
float sum = 0.0f;
|
||||
float wsum = 0.0f;
|
||||
|
||||
for (int64_t sy = y_min; sy < y_max; ++sy) {
|
||||
const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
|
||||
for (int64_t sx = x_min; sx < x_max; ++sx) {
|
||||
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
|
||||
const float w = wx * wy;
|
||||
const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
|
||||
sum += (*src_ptr) * w;
|
||||
wsum += w;
|
||||
}
|
||||
}
|
||||
|
||||
const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
|
||||
dst_ptr[i0] = v;
|
||||
}
|
||||
} else {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
||||
const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
|
||||
const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
|
||||
const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
|
||||
|
||||
device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
|
||||
device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
|
||||
device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
|
||||
device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
|
||||
|
||||
const float v =
|
||||
(*src00) * (1.0f - fd0) * (1.0f - fd1) +
|
||||
(*src10) * fd0 * (1.0f - fd1) +
|
||||
(*src01) * (1.0f - fd0) * fd1 +
|
||||
(*src11) * fd0 * fd1;
|
||||
|
||||
dst_ptr[i0] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline float bicubic_weight1(float x) {
|
||||
const float a = -0.75f;
|
||||
return ((a + 2) * x - (a + 3)) * x * x + 1;
|
||||
}
|
||||
|
||||
static inline float bicubic_weight2(float x) {
|
||||
const float a = -0.75f;
|
||||
return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
|
||||
}
|
||||
|
||||
kernel void kernel_upscale_bicubic_f32(
|
||||
constant ggml_metal_kargs_upscale & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
|
||||
const int64_t i03 = i3 / args.sf3;
|
||||
const int64_t i02 = i2 / args.sf2;
|
||||
|
||||
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
|
||||
const int64_t i01 = (int64_t)floor(f01);
|
||||
const float fd1 = f01 - (float)i01;
|
||||
|
||||
const float w_y0 = bicubic_weight2(fd1 + 1.0f);
|
||||
const float w_y1 = bicubic_weight1(fd1);
|
||||
const float w_y2 = bicubic_weight1(1.0f - fd1);
|
||||
const float w_y3 = bicubic_weight2(2.0f - fd1);
|
||||
|
||||
const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
|
||||
|
||||
device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
||||
const int64_t i00 = (int64_t)floor(f00);
|
||||
const float fd0 = f00 - (float)i00;
|
||||
|
||||
const float w_x0 = bicubic_weight2(fd0 + 1.0f);
|
||||
const float w_x1 = bicubic_weight1(fd0);
|
||||
const float w_x2 = bicubic_weight1(1.0f - fd0);
|
||||
const float w_x3 = bicubic_weight2(2.0f - fd0);
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int dy = -1; dy <= 2; ++dy) {
|
||||
const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
|
||||
const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
|
||||
|
||||
for (int dx = -1; dx <= 2; ++dx) {
|
||||
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
|
||||
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
|
||||
|
||||
const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
|
||||
sum += (*src_ptr) * wx * wy;
|
||||
}
|
||||
}
|
||||
|
||||
dst_ptr[i0] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_pad_f32(
|
||||
constant ggml_metal_kargs_pad & args,
|
||||
device const char * src0,
|
||||
|
||||
@@ -25,6 +25,11 @@ ggml_add_backend_library(ggml-sycl
|
||||
|
||||
file(GLOB GGML_HEADERS_SYCL "*.hpp")
|
||||
file(GLOB GGML_SOURCES_SYCL "*.cpp")
|
||||
file(GLOB SRCS "template-instances/fattn-tile*.cpp")
|
||||
list(APPEND GGML_SOURCES_SYCL ${SRCS})
|
||||
file(GLOB SRCS "template-instances/fattn-vec*.cpp")
|
||||
list(APPEND GGML_SOURCES_SYCL ${SRCS})
|
||||
|
||||
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
|
||||
|
||||
if (WIN32)
|
||||
@@ -145,6 +150,7 @@ else()
|
||||
endif()
|
||||
|
||||
if (GGML_SYCL_GRAPH)
|
||||
message(STATUS "find GGML_SYCL_GRAPH")
|
||||
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "dequantize.hpp"
|
||||
#include "dmmv.hpp"
|
||||
#include "element_wise.hpp"
|
||||
#include "fattn.hpp"
|
||||
#include "gla.hpp"
|
||||
#include "im2col.hpp"
|
||||
#include "mmq.hpp"
|
||||
|
||||
@@ -19,10 +19,13 @@
|
||||
#include <string>
|
||||
|
||||
#include "dpct/helper.hpp"
|
||||
#include "ggml.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-sycl.h"
|
||||
#include "presets.hpp"
|
||||
#include "sycl_hw.hpp"
|
||||
|
||||
namespace syclexp = sycl::ext::oneapi::experimental;
|
||||
|
||||
#if GGML_SYCL_DNNL
|
||||
#include "dnnl.hpp"
|
||||
@@ -31,6 +34,9 @@
|
||||
|
||||
#define GGML_COMMON_DECL_SYCL
|
||||
#define GGML_COMMON_IMPL_SYCL
|
||||
#define SYCL_FLASH_ATTN //remove it to disable FLASH_ATTENTION in building.
|
||||
#define SYCL_FAST_FP16 //don't change. remove it will break fattn-tile.hpp building
|
||||
|
||||
/* suppress warning spam */
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wnested-anon-types"
|
||||
@@ -45,6 +51,8 @@ void ggml_sycl_host_free(void* ptr);
|
||||
extern int g_ggml_sycl_debug;
|
||||
extern int g_ggml_sycl_disable_optimize;
|
||||
extern int g_ggml_sycl_prioritize_dmmv;
|
||||
extern int g_ggml_sycl_enable_flash_attention;
|
||||
|
||||
|
||||
#if defined(__clang__) && __has_builtin(__builtin_expect)
|
||||
// Hint the optimizer to pipeline the more likely following instruction in branches
|
||||
@@ -170,6 +178,10 @@ static size_t g_scratch_offset = 0;
|
||||
|
||||
int get_current_device_id();
|
||||
|
||||
inline int ggml_sycl_get_device() {
|
||||
return get_current_device_id();
|
||||
}
|
||||
|
||||
inline dpct::err0 ggml_sycl_set_device(const int device) try {
|
||||
int current_device_id;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
|
||||
@@ -194,11 +206,14 @@ struct optimize_feature {
|
||||
};
|
||||
|
||||
struct sycl_device_info {
|
||||
int cc; // compute capability
|
||||
int cc; // compute capability
|
||||
int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum
|
||||
// number of compute units on a SYCL device.
|
||||
// size_t smpb; // max. shared memory per block
|
||||
size_t smpbo; // max. shared memory per block (with opt-in)
|
||||
int warp_size; // max sub_group_size of SYCL
|
||||
int max_wg_per_cu; // max work groups per compute unit - refer to
|
||||
// cudaOccupancyMaxActiveBlocksPerMultiprocessor
|
||||
bool vmm; // virtual memory support
|
||||
size_t total_vram;
|
||||
//sycl_hw_info hw_info; \\ device id and aarch, currently not used
|
||||
@@ -435,13 +450,15 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
|
||||
return a;
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
/* use WARP_SIZE or WARP_32_SIZE*/
|
||||
template <int width>
|
||||
static __dpct_inline__ int warp_reduce_sum(int x) {
|
||||
return sycl::reduce_over_group(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>());
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
/* use WARP_SIZE or WARP_32_SIZE*/
|
||||
template <int width>
|
||||
static __dpct_inline__ float warp_reduce_sum(float x) {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
@@ -451,7 +468,19 @@ static __dpct_inline__ float warp_reduce_sum(float x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
/* use WARP_SIZE or WARP_32_SIZE*/
|
||||
template <int width>
|
||||
static __dpct_inline__ float warp_reduce_sum(float x, const sycl::nd_item<3>& item_ct1) {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
x += dpct::permute_sub_group_by_xor(
|
||||
item_ct1.get_sub_group(), x, offset);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
/* use WARP_SIZE or WARP_32_SIZE*/
|
||||
template <int width>
|
||||
static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
@@ -465,7 +494,8 @@ static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
/* use WARP_SIZE or WARP_32_SIZE*/
|
||||
template <int width>
|
||||
static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
@@ -481,7 +511,52 @@ static constexpr int ggml_sycl_get_physical_warp_size() {
|
||||
return WARP_SIZE;
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
/* use WARP_SIZE or WARP_32_SIZE*/
|
||||
template <int width>
|
||||
static __dpct_inline__ int warp_reduce_all(int x) {
|
||||
if (width == ggml_sycl_get_physical_warp_size()) {
|
||||
return sycl::all_of_group(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(),
|
||||
(~0xffffffff &
|
||||
(0x1 << sycl::ext::oneapi::this_work_item::get_sub_group()
|
||||
.get_local_linear_id())) ||
|
||||
x);
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
x = dpct::permute_sub_group_by_xor(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), x,
|
||||
offset, width) &&
|
||||
x;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
/* use WARP_SIZE or WARP_32_SIZE*/
|
||||
template <int width>
|
||||
static __dpct_inline__ int warp_reduce_any(int x) {
|
||||
if (width == ggml_sycl_get_physical_warp_size()) {
|
||||
return sycl::any_of_group(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(),
|
||||
(0xffffffff &
|
||||
(0x1 << sycl::ext::oneapi::this_work_item::get_sub_group()
|
||||
.get_local_linear_id())) &&
|
||||
x);
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
x = dpct::permute_sub_group_by_xor(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), x,
|
||||
offset, width) ||
|
||||
x;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
/* use WARP_SIZE or WARP_32_SIZE*/
|
||||
template <int width>
|
||||
static __dpct_inline__ float warp_reduce_max(float x) {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
@@ -629,6 +704,42 @@ static const sycl::uint3 init_fastdiv_values(uint32_t d) {
|
||||
return sycl::uint3(mp, L, d);
|
||||
}
|
||||
|
||||
// Maximum number of bytes that can be copied in a single instruction.
|
||||
// Set by test result.
|
||||
static constexpr int ggml_sycl_get_max_cpy_bytes() {
|
||||
return 16;
|
||||
}
|
||||
|
||||
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes.
|
||||
template <int nbytes, int alignment = 0>
|
||||
static __dpct_inline__ void ggml_sycl_memcpy_1(void * dst, const void * src) {
|
||||
if constexpr (alignment != 0) {
|
||||
static_assert(nbytes % alignment == 0, "bad alignment");
|
||||
}
|
||||
constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < nbytes/nb_per_cpy; ++i) {
|
||||
if constexpr (nb_per_cpy == 1) {
|
||||
((char *) dst)[i] = ((const char *) src)[i];
|
||||
} else if constexpr (nb_per_cpy == 2) {
|
||||
((short *) dst)[i] = ((const short *) src)[i];
|
||||
} else if constexpr (nb_per_cpy == 4) {
|
||||
((int *) dst)[i] = ((const int *) src)[i];
|
||||
} else if constexpr (nb_per_cpy == 8) {
|
||||
((sycl::int2 *) dst)[i] = ((const sycl::int2 *) src)[i];
|
||||
} else if constexpr (nb_per_cpy == 16) {
|
||||
((sycl::int4 *) dst)[i] = ((const sycl::int4 *) src)[i];
|
||||
} else {
|
||||
static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
sycl::half2 __dpct_inline__ make_half2( T x, T y) {
|
||||
sycl::half2 res(static_cast<sycl::half>(x),static_cast<sycl::half>(y));
|
||||
return res;
|
||||
}
|
||||
|
||||
static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) {
|
||||
const uint32_t hi = sycl::mul_hi<unsigned>(n, fastdiv_values.x());
|
||||
@@ -636,6 +747,17 @@ static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_va
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
sycl::float2 __dpct_inline__ make_float2( T x, T y) {
|
||||
sycl::float2 res(static_cast<float>(x),static_cast<float>(y));
|
||||
return res;
|
||||
}
|
||||
|
||||
sycl::float2 __dpct_inline__ __half22float2(sycl::half2 &H) {
|
||||
sycl::float2 float2_value(static_cast<float>(H.x()), static_cast<float>(H.y()));
|
||||
return float2_value;
|
||||
}
|
||||
|
||||
static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) {
|
||||
const uint32_t div_val = fastdiv(n, fastdiv_values);
|
||||
const uint32_t mod_val = n - div_val * fastdiv_values.z();
|
||||
@@ -659,5 +781,97 @@ static __dpct_inline__ float ggml_sycl_e8m0_to_fp32(uint8_t x) {
|
||||
return result;
|
||||
}
|
||||
|
||||
sycl::float2 __dpct_inline__ __half22float2(const sycl::half2 &H) {
|
||||
sycl::float2 float2_value(static_cast<float>(H.x()), static_cast<float>(H.y()));
|
||||
return float2_value;
|
||||
}
|
||||
|
||||
float __dpct_inline__ __half2float(sycl::half H) {
|
||||
return static_cast<float>(H);
|
||||
}
|
||||
|
||||
static __dpct_inline__ void ggml_sycl_mad(float & acc, const float v, const float u) {
|
||||
acc += v*u;
|
||||
}
|
||||
|
||||
static __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::float2 v, const sycl::float2 u) {
|
||||
acc += v.x() * u.x();
|
||||
acc += v.y() * u.y();
|
||||
}
|
||||
|
||||
static __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::half2 v, const sycl::half2 u) {
|
||||
#ifdef GGML_SYCL_F16
|
||||
const sycl::float2 tmp = (v * u).template convert<float, sycl::rounding_mode::automatic>();
|
||||
acc += tmp.x() + tmp.y();
|
||||
#else
|
||||
const sycl::float2 tmpv = __half22float2(v);
|
||||
const sycl::float2 tmpu = __half22float2(u);
|
||||
acc += tmpv.x() * tmpu.x();
|
||||
acc += tmpv.y() * tmpu.y();
|
||||
#endif // GGML_SYCL_F16
|
||||
}
|
||||
|
||||
static __dpct_inline__ void ggml_sycl_mad(sycl::half2 & acc, const sycl::half2 v, const sycl::half2 u) {
|
||||
#ifdef GGML_SYCL_F16
|
||||
acc += v*u;
|
||||
#else
|
||||
const sycl::float2 tmpv = __half22float2(v);
|
||||
const sycl::float2 tmpu = __half22float2(u);
|
||||
sycl::float2 tmpacc = __half22float2(acc);
|
||||
// tmpacc.x += tmpv.x() * tmpu.x();
|
||||
// tmpacc.y += tmpv.y() * tmpu.y();
|
||||
sycl::float2 tmp1(tmpacc.x() + tmpv.x() * tmpu.x(), tmpacc.y() + tmpv.y() * tmpu.y());
|
||||
acc = make_half2(tmp1.x(), tmp1.y());
|
||||
#endif // GGML_SYCL_F16
|
||||
}
|
||||
|
||||
template <int n>
|
||||
struct ggml_sycl_unroll {
|
||||
template <typename Func, typename... Args>
|
||||
void operator()(const Func & f, Args... args) const {
|
||||
f(n - 1, args...);
|
||||
ggml_sycl_unroll<n - 1>{}(f, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ggml_sycl_unroll<1> {
|
||||
template <typename Func, typename... Args>
|
||||
void operator()(const Func & f, Args... args) const {
|
||||
f(0, args...);
|
||||
}
|
||||
};
|
||||
|
||||
static __dpct_inline__ sycl::half2 ggml_sycl_hmax2(const sycl::half2 a, const sycl::half2 b) {
|
||||
sycl::half2 ret;
|
||||
reinterpret_cast<sycl::half &>(ret.x()) =
|
||||
sycl::vec<float, 1>(sycl::fmax(a[0], b[0])).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
|
||||
reinterpret_cast<sycl::half &>(ret.y()) =
|
||||
sycl::vec<float, 1>(sycl::fmax(a[1], b[1])).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __dpct_inline__ sycl::half ggml_sycl_hmax(const sycl::half a, const sycl::half b) {
|
||||
return sycl::vec<float, 1>(
|
||||
sycl::fmax(sycl::vec<sycl::half, 1>(a).convert<float, sycl::rounding_mode::automatic>()[0],
|
||||
sycl::vec<sycl::half, 1>(b).convert<float, sycl::rounding_mode::automatic>()[0]))
|
||||
.convert<sycl::half, sycl::rounding_mode::automatic>()[0];
|
||||
}
|
||||
|
||||
static __dpct_inline__ uint32_t __hgt2_mask(const sycl::half2 a, const sycl::half2 b) {
|
||||
const uint32_t mask_low = 0x0000FFFF * (float(a[0]) > float(b[0]));
|
||||
const uint32_t mask_high = 0xFFFF0000 * (float(a[1]) > float(b[1]));
|
||||
return mask_low | mask_high;
|
||||
}
|
||||
|
||||
static __dpct_inline__ uint32_t fastmodulo(uint32_t n, const sycl::uint3 fastdiv_values) {
|
||||
// expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
|
||||
return n - fastdiv(n, fastdiv_values) * fastdiv_values.z();
|
||||
}
|
||||
|
||||
static bool fast_fp16_available(const int cc) {
|
||||
GGML_UNUSED(cc);
|
||||
return true; //Intel GPUs always support FP16.
|
||||
}
|
||||
|
||||
#endif // GGML_SYCL_COMMON_HPP
|
||||
|
||||
@@ -482,6 +482,63 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t
|
||||
});
|
||||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int64_t i00 = 2 * (int64_t(item_ct1.get_local_range(2)) * item_ct1.get_group(2) + item_ct1.get_local_id(2));
|
||||
|
||||
if (i00 >= ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i01 = item_ct1.get_group(1);
|
||||
const int64_t i02 = item_ct1.get_group(0) % ne02;
|
||||
const int64_t i03 = item_ct1.get_group(0) / ne02;
|
||||
|
||||
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
|
||||
|
||||
const int64_t ib = ibx0 + i00/qk; // block index
|
||||
const int64_t iqs = (i00%qk)/qr; // quant index
|
||||
const int64_t iybs = i00 - i00%qk; // y block start index
|
||||
const int64_t y_offset = qr == 1 ? 1 : qk/2;
|
||||
|
||||
// dequantize
|
||||
#ifdef GGML_SYCL_F16
|
||||
sycl::half2 v;
|
||||
#else
|
||||
sycl::float2 v;
|
||||
#endif
|
||||
|
||||
dequantize_kernel(vx, ib, iqs, v);
|
||||
|
||||
const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
|
||||
y[iy0 + 0] = ggml_sycl_cast<dst_t>(v.x());
|
||||
y[iy0 + y_offset] = ggml_sycl_cast<dst_t>(v.y());
|
||||
}
|
||||
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static void dequantize_block_nc_sycl(const void * vx,
|
||||
dst_t * y,
|
||||
const int64_t ne00,
|
||||
const int64_t ne01,
|
||||
const int64_t ne02,
|
||||
const int64_t ne03,
|
||||
const int64_t s01,
|
||||
const int64_t s02,
|
||||
const int64_t s03,
|
||||
dpct::queue_ptr stream) {
|
||||
const dpct::dim3 num_blocks((ne00 + 2 * SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2 * SYCL_DEQUANTIZE_BLOCK_SIZE), ne01,
|
||||
ne02 * ne03);
|
||||
stream->parallel_for(sycl::nd_range<3>(num_blocks * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
dequantize_block_nc<qk, qr, dequantize_kernel>(vx, y, ne00, ne01, ne02, s01, s02, s03);
|
||||
});
|
||||
}
|
||||
template <typename src_t, typename dst_t>
|
||||
static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
|
||||
const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
@@ -662,7 +719,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
|
||||
}
|
||||
}
|
||||
|
||||
to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
|
||||
|
||||
to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
return convert_unary_nc_sycl<float>;
|
||||
@@ -670,6 +728,16 @@ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
|
||||
case GGML_TYPE_BF16:
|
||||
return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>;
|
||||
#endif
|
||||
case GGML_TYPE_Q4_0:
|
||||
return dequantize_block_nc_sycl<QK4_0, QR4_0, dequantize_q4_0>;
|
||||
case GGML_TYPE_Q4_1:
|
||||
return dequantize_block_nc_sycl<QK4_1, QR4_1, dequantize_q4_1>;
|
||||
case GGML_TYPE_Q5_0:
|
||||
return dequantize_block_nc_sycl<QK5_0, QR5_0, dequantize_q5_0>;
|
||||
case GGML_TYPE_Q5_1:
|
||||
return dequantize_block_nc_sycl<QK5_1, QR5_1, dequantize_q5_1>;
|
||||
case GGML_TYPE_Q8_0:
|
||||
return dequantize_block_nc_sycl<QK8_0, QR8_0, dequantize_q8_0>;
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -29,6 +29,21 @@ using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne0
|
||||
int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue);
|
||||
|
||||
typedef to_t_nc_sycl_t<sycl::half> to_fp16_nc_sycl_t;
|
||||
to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type);
|
||||
to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type);
|
||||
|
||||
template<typename dst_t, typename src_t>
|
||||
inline dst_t ggml_sycl_cast(src_t x) {
|
||||
if constexpr (std::is_same_v<dst_t, src_t>) {
|
||||
return x;
|
||||
} else if constexpr (std::is_same_v<dst_t, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::bfloat16(float(x));
|
||||
} else if constexpr (std::is_same_v<src_t, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<float>(x);
|
||||
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
|
||||
return int32_t(x);
|
||||
} else {
|
||||
return float(x);
|
||||
}
|
||||
}
|
||||
|
||||
#endif // GGML_SYCL_CONVERT_HPP
|
||||
|
||||
@@ -18,7 +18,7 @@ static void count_equal(const T *__restrict__ x, const T *__restrict__ y,
|
||||
nequal += xi == yi;
|
||||
}
|
||||
|
||||
nequal = warp_reduce_sum(nequal);
|
||||
nequal = warp_reduce_sum<WARP_SIZE>(nequal);
|
||||
|
||||
if (item_ct1.get_local_id(2) != 0) {
|
||||
return;
|
||||
|
||||
@@ -2997,6 +2997,778 @@ namespace dpct
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <int n_nondefault_params, int n_default_params, typename T>
|
||||
class args_selector;
|
||||
|
||||
/// args_selector is a helper class for extracting arguments from an
|
||||
/// array of pointers to arguments or buffer of arguments to pass to a
|
||||
/// kernel function.
|
||||
///
|
||||
/// \param R(Ts...) The type of the kernel
|
||||
/// \param n_nondefault_params The number of nondefault parameters of the
|
||||
/// kernel (excluding parameters that like sycl::nd_item, etc.) \param
|
||||
/// n_default_params The number of default parameters of the kernel
|
||||
///
|
||||
/// Example usage:
|
||||
/// With the following kernel:
|
||||
/// void foo(sycl::float2 *x, int n, sycl::nd_item<3> item_ct1, float
|
||||
/// f=.1) {}
|
||||
/// and with the declaration:
|
||||
/// args_selector<2, 1, decltype(foo)> selector(kernelParams, extra);
|
||||
/// we have:
|
||||
/// selector.get<0>() returns a reference to sycl::float*,
|
||||
/// selector.get<1>() returns a reference to int,
|
||||
/// selector.get<2>() returns a reference to float
|
||||
template <int n_nondefault_params, int n_default_params, typename R,
|
||||
typename... Ts>
|
||||
class args_selector<n_nondefault_params, n_default_params, R(Ts...)> {
|
||||
private:
|
||||
void **kernel_params;
|
||||
char *args_buffer;
|
||||
|
||||
template <int i> static constexpr int account_for_default_params() {
|
||||
constexpr int n_total_params = sizeof...(Ts);
|
||||
if constexpr (i >= n_nondefault_params) {
|
||||
return n_total_params - n_default_params +
|
||||
(i - n_nondefault_params);
|
||||
} else {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/// Get the type of the ith argument of R(Ts...)
|
||||
/// \param [in] i Index of parameter to get
|
||||
/// \returns Type of ith parameter
|
||||
template <int i>
|
||||
using arg_type = std::tuple_element_t<account_for_default_params<i>(),
|
||||
std::tuple<Ts...>>;
|
||||
static constexpr int params_num = sizeof...(Ts);
|
||||
|
||||
private:
|
||||
template <int i> static constexpr int get_offset() {
|
||||
if constexpr (i == 0) {
|
||||
// we can assume args_buffer is properly aligned to the
|
||||
// first argument
|
||||
return 0;
|
||||
} else {
|
||||
constexpr int prev_off = get_offset<i - 1>();
|
||||
constexpr int prev_past_end =
|
||||
prev_off + sizeof(arg_type<i - 1>);
|
||||
using T = arg_type<i>;
|
||||
// is the past-the-end of the i-1st element properly aligned
|
||||
// with the ith element's alignment?
|
||||
if constexpr (prev_past_end % alignof(T) == 0) {
|
||||
return prev_past_end;
|
||||
}
|
||||
// otherwise bump prev_past_end to match alignment
|
||||
else {
|
||||
return prev_past_end +
|
||||
(alignof(T) - (prev_past_end % alignof(T)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static char *get_args_buffer(void **extra) {
|
||||
if (!extra)
|
||||
return nullptr;
|
||||
for (; (std::size_t)*extra != 0; ++extra) {
|
||||
if ((std::size_t)*extra == 1) {
|
||||
return static_cast<char *>(*(extra + 1));
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
public:
|
||||
/// If kernel_params is nonnull, then args_selector will
|
||||
/// extract arguments from kernel_params. Otherwise, it
|
||||
/// will extract them from extra.
|
||||
/// \param [in] kernel_params Array of pointers to arguments
|
||||
/// a or null pointer.
|
||||
/// \param [in] extra Array containing pointer to argument buffer.
|
||||
args_selector(void **kernel_params, void **extra)
|
||||
: kernel_params(kernel_params),
|
||||
args_buffer(get_args_buffer(extra)) {}
|
||||
|
||||
/// Get a reference to the ith argument extracted from kernel_params
|
||||
/// or extra.
|
||||
/// \param [in] i Index of argument to get
|
||||
/// \returns Reference to the ith argument
|
||||
template <int i> arg_type<i> &get() {
|
||||
if (kernel_params) {
|
||||
return *static_cast<arg_type<i> *>(kernel_params[i]);
|
||||
} else {
|
||||
return *reinterpret_cast<arg_type<i> *>(args_buffer +
|
||||
get_offset<i>());
|
||||
}
|
||||
}
|
||||
}; // COPY from DPCT head file
|
||||
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
|
||||
|
||||
/// Utility class for launching SYCL kernels through kernel
|
||||
/// function wrapper.
|
||||
/// For example:
|
||||
/// A SYCL kernel function:
|
||||
/// void kernel_func(int *ptr, sycl::nd_item<3> item);
|
||||
/// Kernel function wrapper:
|
||||
/// void kernel_func_wrapper(int *ptr) {
|
||||
/// sycl::queue queue = *dpct::kernel_launcher::_que;
|
||||
/// unsigned int localMemSize = dpct::kernel_launcher::_local_mem_size;
|
||||
/// sycl::nd_range<3> nr = dpct::kernel_launcher::_nr;
|
||||
/// queue.parallel_for(
|
||||
/// nr,
|
||||
/// [=](sycl::nd_item<3> item_ct1) {
|
||||
/// kernel_func(ptr, item_ct1);
|
||||
/// });
|
||||
/// }
|
||||
/// Then launch the kernel through wrapper like:
|
||||
/// typedef void(*fpt)(int *);
|
||||
/// fpt fp = kernel_func_wrapper;
|
||||
/// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0,
|
||||
/// device_ptr);
|
||||
/// If the origin function type is erased, then need to register it first:
|
||||
/// void *fp = (void *)wrapper_register(&kernel_func_wrapper).get();
|
||||
/// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), args,
|
||||
/// 0, 0);
|
||||
class kernel_launcher {
|
||||
template <typename FuncT, typename ArgSelector, std::size_t... Index>
|
||||
static void launch_helper(FuncT &&func, ArgSelector &selector,
|
||||
std::index_sequence<Index...>) {
|
||||
func(selector.template get<Index>()...);
|
||||
}
|
||||
static void set_execution_config(dim3 group_range, dim3 local_range,
|
||||
unsigned int local_mem_size,
|
||||
queue_ptr que) {
|
||||
if (que) {
|
||||
_que = que;
|
||||
} else {
|
||||
_que = &get_default_queue();
|
||||
}
|
||||
_nr = sycl::nd_range<3>(
|
||||
static_cast<sycl::range<3>>(group_range * local_range),
|
||||
static_cast<sycl::range<3>>(local_range));
|
||||
_local_mem_size = local_mem_size;
|
||||
|
||||
|
||||
};
|
||||
static inline std::mutex kernel_function_ptr_map_mutex;
|
||||
|
||||
public:
|
||||
/// Variables for storing execution configuration.
|
||||
static inline thread_local sycl::queue *_que = nullptr;
|
||||
static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>();
|
||||
static inline thread_local unsigned int _local_mem_size = 0;
|
||||
/// Map for retrieving launchable functor from a raw pointer.
|
||||
static inline std::map<
|
||||
const void *,
|
||||
std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>>
|
||||
kernel_function_ptr_map = {};
|
||||
|
||||
/// Registers a kernel function pointer with a corresponding launchable
|
||||
/// functor.
|
||||
/// \param [in] func Pointer to the kernel function.
|
||||
/// \param [in] launcher Functor to handle kernel invocation.
|
||||
static void register_kernel_ptr(
|
||||
const void *func,
|
||||
std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>
|
||||
launcher) {
|
||||
std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
|
||||
kernel_function_ptr_map[func] = std::move(launcher);
|
||||
}
|
||||
/// Launches a kernel function with arguments provided directly through
|
||||
/// kernel function wrapper.
|
||||
/// \tparam FuncT Type of the kernel function wrapper.
|
||||
/// \tparam ArgsT Types of kernel arguments.
|
||||
/// \param [in] func Pointer to the kernel function wrapper.
|
||||
/// \param [in] group_range SYCL group range.
|
||||
/// \param [in] local_range SYCL local range.
|
||||
/// \param [in] local_mem_size The size of local memory required by the
|
||||
/// kernel function. \param [in] que SYCL queue used to execute kernel.
|
||||
/// \param [in] args Kernel arguments.
|
||||
template <typename FuncT, typename... ArgsT>
|
||||
static std::enable_if_t<std::is_invocable_v<FuncT *, ArgsT...>, void>
|
||||
launch(FuncT *func, dim3 group_range, dim3 local_range,
|
||||
unsigned int local_mem_size, queue_ptr que, ArgsT... args) {
|
||||
set_execution_config(group_range, local_range, local_mem_size, que);
|
||||
func(args...);
|
||||
}
|
||||
/// Launches a kernel function through registered kernel function
|
||||
/// wrapper. \param [in] func Pointer to the registered kernel function
|
||||
/// wrapper. \param [in] group_range SYCL group range. \param [in]
|
||||
/// local_range SYCL local range. \param [in] args Array of pointers to
|
||||
/// kernel arguments. \param [in] local_mem_size The size of local
|
||||
/// memory required by the kernel function. \param [in] que SYCL queue
|
||||
/// used to execute kernel.
|
||||
static void launch(const void *func, dim3 group_range, dim3 local_range,
|
||||
void **args, unsigned int local_mem_size,
|
||||
queue_ptr que) {
|
||||
std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
|
||||
auto Iter = kernel_function_ptr_map.find(func);
|
||||
if (Iter == kernel_function_ptr_map.end()) {
|
||||
throw std::runtime_error("dpct::launch() : no registered "
|
||||
"kernel function wrapper found.");
|
||||
}
|
||||
(Iter->second)(group_range, local_range, args, local_mem_size, que);
|
||||
}
|
||||
/// Launches a kernel function with packed arguments through kernel
|
||||
/// function wrapper.
|
||||
/// \tparam FuncT Type of the kernel function wrapper.
|
||||
/// \param [in] func Pointer to the kernel function wrapper.
|
||||
/// \param [in] group_range SYCL group range.
|
||||
/// \param [in] local_range SYCL local range.
|
||||
/// \param [in] args Array of pointers to kernel arguments.
|
||||
/// \param [in] local_mem_size The size of local memory required by the
|
||||
/// kernel function. \param [in] que SYCL queue used to execute kernel.
|
||||
template <typename FuncT>
|
||||
static std::enable_if_t<std::is_function_v<FuncT>, void>
|
||||
launch(FuncT *func, dim3 group_range, dim3 local_range, void **args,
|
||||
unsigned int local_mem_size, queue_ptr que) {
|
||||
constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num;
|
||||
set_execution_config(group_range, local_range, local_mem_size, que);
|
||||
args_selector<p_num, p_num, FuncT> selector(args, nullptr);
|
||||
launch_helper(func, selector, std::make_index_sequence<p_num>{});
|
||||
}
|
||||
}; // COPY from DPCT head file
|
||||
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/kernel.hpp
|
||||
|
||||
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
|
||||
template <typename T>
|
||||
T select_from_sub_group(
|
||||
sycl::sub_group g,
|
||||
T x,
|
||||
int remote_local_id,
|
||||
int logical_sub_group_size = 32) {
|
||||
unsigned int start_index = g.get_local_linear_id() /
|
||||
logical_sub_group_size *
|
||||
logical_sub_group_size;
|
||||
return sycl::select_from_group(
|
||||
g, x, start_index + remote_local_id % logical_sub_group_size);
|
||||
}
|
||||
|
||||
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
|
||||
template <typename T>
|
||||
void ldmatrix(uintptr_t addr, T* m, bool trans = false, unsigned mat = 0) {
|
||||
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
|
||||
int lane = sg.get_local_linear_id();
|
||||
|
||||
int lane_group8_row = lane / 8;
|
||||
int lane_group8_col = lane % 8;
|
||||
|
||||
if (!trans) {
|
||||
// calculate the source lane
|
||||
int src_lane = 2 * lane_group8_row;
|
||||
if (lane_group8_col >= 4)
|
||||
src_lane += 1;
|
||||
|
||||
// Broadcast the address from the source lane
|
||||
auto recv_addr_uintp =
|
||||
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
|
||||
|
||||
// Cast the received address from uintptr_t to the type of 'm'
|
||||
auto recv_addr = reinterpret_cast<T*>(recv_addr_uintp);
|
||||
|
||||
// Non-transposed load
|
||||
*m = recv_addr[lane_group8_col % 4];
|
||||
} else {
|
||||
// calculate the source lane
|
||||
int src_lane = (lane % 4) * 2;
|
||||
|
||||
// Broadcast the address from the source lane
|
||||
auto recv_addr_uintp_1 =
|
||||
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
|
||||
auto recv_addr_uintp_2 =
|
||||
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);
|
||||
|
||||
// Cast the received address from uintptr_t to 'half *'
|
||||
auto recv_addr_1 = reinterpret_cast<sycl::half*>(recv_addr_uintp_1);
|
||||
auto recv_addr_2 = reinterpret_cast<sycl::half*>(recv_addr_uintp_2);
|
||||
|
||||
// Transposed load
|
||||
int index = lane / 4;
|
||||
sycl::half val0 = recv_addr_1[index];
|
||||
sycl::half val1 = recv_addr_2[index];
|
||||
|
||||
// Combine the two 16-bits into one 32-bit value
|
||||
sycl::half2 val = sycl::half2(val0, val1);
|
||||
*m = *reinterpret_cast<T*>(&val);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ldmatrix(uintptr_t addr, T* m1, T* m2, bool trans = false) {
|
||||
// Load 1st matrix
|
||||
ldmatrix(addr, m1, trans, 0);
|
||||
// Load 2nd matrix
|
||||
ldmatrix(addr, m2, trans, 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ldmatrix(
|
||||
uintptr_t addr, T* m1, T* m2, T* m3, T* m4, bool trans = false) {
|
||||
// Load 1st matrix
|
||||
ldmatrix(addr, m1, trans, 0);
|
||||
// Load 2nd matrix
|
||||
ldmatrix(addr, m2, trans, 1);
|
||||
// Load 3rd matrix
|
||||
ldmatrix(addr, m3, trans, 2);
|
||||
// Load 4th matrix
|
||||
ldmatrix(addr, m4, trans, 3);
|
||||
}
|
||||
|
||||
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
|
||||
|
||||
/// A helper struct that defines the pack type for the input matrix
|
||||
/// fragments
|
||||
/// of mma() function based on the type of input matrix fragments.
|
||||
/// The MMAType struct is specialized for different types of input matrices.
|
||||
/// Currently, the specialization for f16, bf16 and s8 types is defined
|
||||
/// below. \tparam [in] T The type of the input matrix fragments
|
||||
template <typename T>
|
||||
struct MMAType {
|
||||
using PackType = uint32_t;
|
||||
};
|
||||
|
||||
/// Each work item of a sub-group (limited to size 32) calling this function
|
||||
/// calculates a subset fragment for the output matrix D using MAD operation
|
||||
/// on A, B & C matrix fragments (D = A * B + C). Current supported shapes &
|
||||
/// types:
|
||||
/// - m8n8k4 (f32.f16.f16.f32)
|
||||
/// - m8n8k16 (s32.s8.s8.s32)
|
||||
/// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
|
||||
/// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32)
|
||||
/// - m16n8k32 (s32.s8.s8.s32)
|
||||
/// Here, m, n & k define the shapes of A, B & C matrices respectively
|
||||
/// (A = [m x k], B = [k x n], C = [m x n]).
|
||||
/// \tparam [in] M The rows of A, C & D matrices
|
||||
/// \tparam [in] N The columns of B, C, D matrices
|
||||
/// \tparam [in] K The columns & rows of A & B matrices respectively
|
||||
/// \tparam [in] ABType The type of the input matrix (A & B) fragment
|
||||
/// \tparam [in] CDType The type of the output matrix (C & D) fragment
|
||||
/// \param [out] d_mat_frag The fragment of the output matrix D to store the
|
||||
/// result of A * B + C
|
||||
/// \param [in] a_mat_frag The fragment of the input matrix A to be
|
||||
/// multiplied with B matrix fragment \param [in] b_mat_frag The fragment of
|
||||
/// the input matrix B to be multiplied with A matrix fragment \param [in]
|
||||
/// c_mat_frag The fragment of the input matrix C to be added with the
|
||||
/// result of A * B fragments
|
||||
template <int M, int N, int K, typename ABType, typename CDType>
|
||||
void mma(
|
||||
volatile void** d_mat_frag,
|
||||
void* a_mat_frag,
|
||||
void* b_mat_frag,
|
||||
void* c_mat_frag) {
|
||||
auto d = reinterpret_cast<volatile CDType**>(d_mat_frag);
|
||||
auto a =
|
||||
reinterpret_cast<typename MMAType<ABType>::PackType*>(a_mat_frag);
|
||||
auto b =
|
||||
reinterpret_cast<typename MMAType<ABType>::PackType*>(b_mat_frag);
|
||||
auto c = reinterpret_cast<CDType*>(c_mat_frag);
|
||||
|
||||
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
|
||||
int lane = sg.get_local_linear_id();
|
||||
|
||||
static_assert(
|
||||
(M == 8 && N == 8 && K == 4) || (M == 8 && N == 8 && K == 16) ||
|
||||
(M == 16 && N == 8 && K == 8) || (M == 16 && N == 8 && K == 16) ||
|
||||
(M == 16 && N == 8 && K == 32),
|
||||
"Unsupported MMA shape!");
|
||||
|
||||
short row_load_offset = 4 * (lane >> 2);
|
||||
short col_load_offset = 8 * (lane % 4);
|
||||
|
||||
if constexpr (M == 8 && N == 8 && K == 4) {
|
||||
if constexpr (std::is_floating_point_v<CDType>) {
|
||||
col_load_offset = row_load_offset % 16;
|
||||
|
||||
// Init D matrix with fragments of C matrix
|
||||
*d[0] = c[0];
|
||||
*d[1] = c[1];
|
||||
*d[2] = c[2];
|
||||
*d[3] = c[3];
|
||||
*d[4] = c[4];
|
||||
*d[5] = c[5];
|
||||
*d[6] = c[6];
|
||||
*d[7] = c[7];
|
||||
|
||||
// Calculate the row and col offset indices to iterate through the row
|
||||
// & col fragments of A & B matrices
|
||||
int r_ind = (lane % 2) ? 1 : 0;
|
||||
int c_ind = ((lane % 4) / 2) ? 2 : 0;
|
||||
|
||||
// Each sub-group is responsible for computing a fragment size of 8*8
|
||||
// elements of matrix D for each of 4 MMA computations.
|
||||
// Each work item computes 8 elements of matrix D by gathering
|
||||
// their corresponding col & row matrix fragments of length k (4)
|
||||
// from A & B matrices respectively using below mapping logic:
|
||||
// row0 = (i % 4) if (lane < 16) else (i % 4) + 4
|
||||
// col0 = (lane % 4)
|
||||
// As each row & col fragment of A & B matrices is distributed across
|
||||
// 4 work items, each iteration of below loop loads a partial fragment
|
||||
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
||||
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
// Load partial fragment from col0 of matrix A ({a0, a1})
|
||||
recv_a[0] =
|
||||
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
||||
// Load partial fragment from col0 of matrix A ({a2, a3})
|
||||
recv_a[1] =
|
||||
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
||||
|
||||
// Load partial fragment from row0 of matrix B ({b0, b1})
|
||||
recv_b[0] =
|
||||
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
||||
// Load partial fragment from row0 of matrix B ({b2, b3})
|
||||
recv_b[1] =
|
||||
dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
|
||||
|
||||
auto ra = reinterpret_cast<ABType*>(recv_a);
|
||||
auto rb = reinterpret_cast<ABType*>(recv_b);
|
||||
|
||||
// Each work item calculates a partial product of A & B matrix
|
||||
// fragments and adds it to the corresponding D matrix fragment (for
|
||||
// even work item indices) d0 += col0{ a0 } * row0{ b0 } d1 += col0{
|
||||
// a0 } * row0{ b1 } d2 += col1{ a2 } * row0{ b0 } d3 += col1{ a2 }
|
||||
// * row0{ b1 } (for odd work item indices) d0 += col0{ a1 } * row0{
|
||||
// b2 } d1 += col0{ a1 } * row0{ b3 } d2 += col1{ a3 } * row0{ b2 }
|
||||
// d3 += col1{ a3 } * row0{ b3 }
|
||||
*d[0] +=
|
||||
static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
|
||||
*d[1] += static_cast<float>(ra[r_ind]) *
|
||||
static_cast<float>(rb[c_ind + 1]);
|
||||
*d[2] += static_cast<float>(ra[r_ind + 2]) *
|
||||
static_cast<float>(rb[c_ind]);
|
||||
*d[3] += static_cast<float>(ra[r_ind + 2]) *
|
||||
static_cast<float>(rb[c_ind + 1]);
|
||||
|
||||
// Load partial fragment from row1 of matrix B ({b0, b1})
|
||||
recv_b[0] =
|
||||
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 16);
|
||||
// Load partial fragment from row1 of matrix B ({b2, b3})
|
||||
recv_b[1] =
|
||||
dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 16);
|
||||
|
||||
// (for even work item indices)
|
||||
// d0 += col0{ a0 } * row1{ b0 }
|
||||
// d1 += col0{ a0 } * row1{ b1 }
|
||||
// d2 += col1{ a2 } * row1{ b0 }
|
||||
// d3 += col1{ a2 } * row1{ b1 }
|
||||
// (for odd work item indices)
|
||||
// d0 += col0{ a1 } * row1{ b2 }
|
||||
// d1 += col0{ a1 } * row1{ b3 }
|
||||
// d2 += col1{ a3 } * row1{ b2 }
|
||||
// d3 += col1{ a3 } * row1{ b3 }
|
||||
*d[4] +=
|
||||
static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
|
||||
*d[5] += static_cast<float>(ra[r_ind]) *
|
||||
static_cast<float>(rb[c_ind + 1]);
|
||||
*d[6] += static_cast<float>(ra[r_ind + 2]) *
|
||||
static_cast<float>(rb[c_ind]);
|
||||
*d[7] += static_cast<float>(ra[r_ind + 2]) *
|
||||
static_cast<float>(rb[c_ind + 1]);
|
||||
}
|
||||
}
|
||||
} else if constexpr (M == 8 && N == 8 && K == 16) {
|
||||
if constexpr (std::is_integral_v<ABType>) {
|
||||
// Init D matrix with fragments of C matrix
|
||||
*d[0] = c[0];
|
||||
*d[1] = c[1];
|
||||
|
||||
// Each sub-group is responsible for computing a fragment size of 16*8
|
||||
// elements of matrix D.
|
||||
// Each work item computes 2 elements of matrix D by gathering
|
||||
// their corresponding row & col matrix fragments of length k (16)
|
||||
// from A & B matrices respectively using below mapping logic:
|
||||
// row0 = ((lane % 4) * 4) + i
|
||||
// col0 = (lane >> 2)
|
||||
// As each row & col fragment of A & B matrices is distributed across
|
||||
// 4 work items, each iteration of below loop loads a partial fragment
|
||||
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
||||
for (int i = 0; i < 4; i++) {
|
||||
typename MMAType<ABType>::PackType recv_a, recv_b[2];
|
||||
|
||||
// Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
|
||||
recv_a = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
||||
// Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
|
||||
recv_b[0] =
|
||||
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
||||
// Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
|
||||
recv_b[1] =
|
||||
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
|
||||
|
||||
auto a = reinterpret_cast<ABType*>(&recv_a);
|
||||
auto b = reinterpret_cast<ABType*>(recv_b);
|
||||
|
||||
// Each work item calculates a partial product of A & B matrix
|
||||
// fragments and adds it to the corresponding D matrix fragment d0
|
||||
// += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
|
||||
// a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row0{ a0, a1, a2,
|
||||
// a3 } * col0{ b0, b1, b2, b3 } d3 += row0{ a0, a1, a2, a3 } *
|
||||
// col1{ b0, b1, b2, b3 }
|
||||
for (int j = 0; j < 4; j++) {
|
||||
*d[0] += a[j] * b[j];
|
||||
*d[1] += a[j] * b[j + 4];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if constexpr (M == 16 && N == 8 && K == 8) {
|
||||
if constexpr (std::is_floating_point_v<CDType>) {
|
||||
// Init D matrix fragment with C matrix fragment
|
||||
*d[0] = c[0];
|
||||
*d[1] = c[1];
|
||||
*d[2] = c[2];
|
||||
*d[3] = c[3];
|
||||
|
||||
// Each sub-group is responsible for computing a fragment size of 16*8
|
||||
// elements of matrix D.
|
||||
// Each work item computes 4 elements of matrix D by gathering
|
||||
// their corresponding row & col matrix fragments of length k (8)
|
||||
// from A & B matrices respectively using below mapping logic:
|
||||
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
|
||||
// col0 = (lane % 4) * 2 + (i & 0x1)
|
||||
// As each row & col fragment of A & B matrices is distributed across
|
||||
// 4 work items, each iteration of below loop loads a partial fragment
|
||||
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
||||
for (int i = 0; i < 4; i++) {
|
||||
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
||||
|
||||
// Load partial fragment from row0 of matrix A ({a0, a1})
|
||||
recv_a[0] =
|
||||
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
||||
// Load partial fragment from row1 of matrix A ({a2, a3})
|
||||
recv_a[1] =
|
||||
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
||||
// Load partial fragment from col0 of matrix B ({b0, b1})
|
||||
recv_b[0] =
|
||||
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
||||
// Load partial fragment from col1 of matrix B ({b0, b1})
|
||||
recv_b[1] =
|
||||
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
|
||||
|
||||
auto ra = reinterpret_cast<ABType*>(recv_a);
|
||||
auto rb = reinterpret_cast<ABType*>(recv_b);
|
||||
|
||||
// Each work item calculates a partial product of A & B matrix
|
||||
// fragments and adds it to the corresponding D matrix fragment d0
|
||||
// += row0{ a0, a1 } * col0{ b0, b1 } d1 += row0{ a0, a1 } * col1{
|
||||
// b0, b1 } d2 += row1{ a2, a3 } * col0{ b0, b1 } d3 += row1{ a2, a3
|
||||
// } * col1{ b0, b1 }
|
||||
for (int j = 0; j < 2; j++) {
|
||||
*d[0] += static_cast<float>(ra[j]) * static_cast<float>(rb[j]);
|
||||
*d[1] +=
|
||||
static_cast<float>(ra[j]) * static_cast<float>(rb[j + 2]);
|
||||
*d[2] +=
|
||||
static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j]);
|
||||
*d[3] +=
|
||||
static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j + 2]);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if constexpr (M == 16 && N == 8 && K == 16) {
|
||||
if constexpr (std::is_floating_point_v<CDType>) {
|
||||
// Init D matrix fragment with C matrix fragment
|
||||
*d[0] = c[0];
|
||||
*d[1] = c[1];
|
||||
*d[2] = c[2];
|
||||
*d[3] = c[3];
|
||||
|
||||
// Each sub-group is responsible for computing a fragment size of 16*8
|
||||
// elements of matrix D.
|
||||
// Each work item computes 4 elements of matrix D by gathering
|
||||
// their corresponding row & col matrix fragments of length k (8)
|
||||
// from A & B matrices respectively using below mapping logic:
|
||||
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
|
||||
// col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
|
||||
// As each row & col fragment of A & B matrices is distributed across
|
||||
// 4 work items, each iteration of below loop loads a partial fragment
|
||||
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
||||
for (int i = 0; i < 4; i++) {
|
||||
typename MMAType<ABType>::PackType recv_a[4], recv_b[4];
|
||||
|
||||
// Load partial fragment from row0 of matrix A ({a0, a1})
|
||||
recv_a[0] =
|
||||
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
||||
// Load partial fragment from row0 of matrix A ({a2, a3})
|
||||
recv_a[1] =
|
||||
dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
|
||||
// Load partial fragment from row1 of matrix A ({a0, a1})
|
||||
recv_a[2] =
|
||||
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
||||
// Load partial fragment from row1 of matrix A ({a2, a3})
|
||||
recv_a[3] =
|
||||
dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
|
||||
|
||||
// Load partial fragment from col0 of matrix B ({b0, b1})
|
||||
recv_b[0] =
|
||||
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
||||
// Load partial fragment from col0 of matrix B ({b2, b3})
|
||||
recv_b[1] =
|
||||
dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
|
||||
// Load partial fragment from col1 of matrix B ({b0, b1})
|
||||
recv_b[2] =
|
||||
dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);
|
||||
// Load partial fragment from col1 of matrix B ({b2, b3})
|
||||
recv_b[3] =
|
||||
dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);
|
||||
|
||||
auto ra = reinterpret_cast<ABType*>(recv_a);
|
||||
auto rb = reinterpret_cast<ABType*>(recv_b);
|
||||
|
||||
// Each work item calculates a partial product of A & B matrix
|
||||
// fragments and adds it to the corresponding D matrix fragment d0
|
||||
// += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
|
||||
// a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a0, a1, a2,
|
||||
// a3 } * col0{ b0, b1, b2, b3 } d3 += row1{ a0, a1, a2, a3 } *
|
||||
// col1{ b0, b1, b2, b3 }
|
||||
for (int j = 0; j < 4; j++) {
|
||||
*d[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);
|
||||
*d[1] +=
|
||||
static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j + 4]);
|
||||
*d[2] +=
|
||||
static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j]);
|
||||
*d[3] += static_cast<CDType>(ra[j + 4]) *
|
||||
static_cast<CDType>(rb[j + 4]);
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_integral_v<ABType>) {
|
||||
// Init D matrix with fragments of C matrix
|
||||
*d[0] = c[0];
|
||||
*d[1] = c[1];
|
||||
*d[2] = c[2];
|
||||
*d[3] = c[3];
|
||||
|
||||
// Each sub-group is responsible for computing a fragment size of 16*8
|
||||
// elements of matrix D.
|
||||
// Each work item computes 4 elements of matrix D by gathering
|
||||
// their corresponding row & col matrix fragments of length k (8)
|
||||
// from A & B matrices respectively using below mapping logic:
|
||||
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
|
||||
// col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
|
||||
// As each row & col fragment of A & B matrices is distributed across
|
||||
// 4 work items, each iteration of below loop loads a partial fragment
|
||||
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
||||
for (int i = 0; i < 4; i++) {
|
||||
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
||||
|
||||
// Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
|
||||
recv_a[0] =
|
||||
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
||||
// Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
|
||||
recv_a[1] =
|
||||
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
||||
// Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
|
||||
recv_b[0] =
|
||||
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
||||
// Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
|
||||
recv_b[1] =
|
||||
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
|
||||
|
||||
auto ra = reinterpret_cast<ABType*>(recv_a);
|
||||
auto rb = reinterpret_cast<ABType*>(recv_b);
|
||||
|
||||
// Each work item calculates a partial product of A & B matrix
|
||||
// fragments and adds it to the corresponding D matrix fragment d0
|
||||
// += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
|
||||
// a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 } d2 += row1{ a4, a5, a6,
|
||||
// a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *
|
||||
// col1{ b4, b5, b6, b7 }
|
||||
for (int i = 0; i < 4; i++) {
|
||||
*d[0] += ra[i] * rb[i];
|
||||
*d[1] += ra[i] * rb[i + 4];
|
||||
*d[2] += ra[i + 4] * rb[i];
|
||||
*d[3] += ra[i + 4] * rb[i + 4];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if constexpr (M == 16 && N == 8 && K == 32) {
|
||||
if constexpr (std::is_integral_v<ABType>) {
|
||||
// Init D matrix with fragments of C matrix
|
||||
*d[0] = c[0];
|
||||
*d[1] = c[1];
|
||||
*d[2] = c[2];
|
||||
*d[3] = c[3];
|
||||
|
||||
// Each sub-group is responsible for computing a fragment size of 16*8
|
||||
// elements of matrix D.
|
||||
// Each work item computes 4 elements of matrix D by gathering
|
||||
// their corresponding row & col matrix fragments of length k (32)
|
||||
// from A & B matrices respectively using below mapping logic:
|
||||
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
|
||||
// col0 = ((lane % 4) * 4) + (i & 0x3) & col1 = ((lane % 4) * 4) + (i
|
||||
// & 0x3) As each row & col fragment of A & B matrices is distributed
|
||||
// across 4 work items, each iteration of below loop loads a partial
|
||||
// fragment of matrix A (row) and matrix B (col) using the row & col
|
||||
// offsets.
|
||||
for (int i = 0; i < 4; i++) {
|
||||
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
||||
|
||||
// Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
|
||||
recv_a[0] =
|
||||
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
||||
// Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
|
||||
recv_a[1] =
|
||||
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
||||
// Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
|
||||
recv_b[0] =
|
||||
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
||||
// Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
|
||||
recv_b[1] =
|
||||
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
|
||||
|
||||
auto a = reinterpret_cast<ABType*>(recv_a);
|
||||
auto b = reinterpret_cast<ABType*>(recv_b);
|
||||
|
||||
// Each work item calculates a partial product of A & B matrix
|
||||
// fragments and adds it to the corresponding D matrix fragment d0
|
||||
// += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
|
||||
// a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a4, a5, a6,
|
||||
// a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *
|
||||
// col1{ b0, b1, b2, b3 }
|
||||
for (int j = 0; j < 4; j++) {
|
||||
*d[0] += a[j] * b[j];
|
||||
*d[1] += a[j] * b[j + 4];
|
||||
*d[2] += a[j + 4] * b[j];
|
||||
*d[3] += a[j + 4] * b[j + 4];
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
||||
|
||||
// Load partial fragment from row0 of matrix A ({a8, a9, a10, a11})
|
||||
recv_a[0] =
|
||||
dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
|
||||
// Load partial fragment from row1 of matrix A ({a12, a13, a14,
|
||||
// a15})
|
||||
recv_a[1] =
|
||||
dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
|
||||
// Load partial fragment from col0 of matrix B ({b4, b5, b6, b7})
|
||||
recv_b[0] =
|
||||
dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
|
||||
// Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
|
||||
recv_b[1] =
|
||||
dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 4);
|
||||
|
||||
auto a = reinterpret_cast<ABType*>(recv_a);
|
||||
auto b = reinterpret_cast<ABType*>(recv_b);
|
||||
|
||||
// Each work item calculates a partial product of A & B matrix
|
||||
// fragments and adds it to the corresponding D matrix fragment d0
|
||||
// += row0{ a8, a9, a10, a11 } * col0{ b4, b5, b6, b7 } d1 += row0{
|
||||
// a8, a9, a10, a11 } * col1{ b4, b5, b6, b7 } d2 += row1{ a12, a13,
|
||||
// a14, a15 } * col0{ b4, b5, b6, b7 } d3 += row1{ a12, a13, a14,
|
||||
// a15 } * col1{ b4, b5, b6, b7 }
|
||||
for (int j = 0; j < 4; j++) {
|
||||
*d[0] += a[j] * b[j];
|
||||
*d[1] += a[j] * b[j + 4];
|
||||
*d[2] += a[j + 4] * b[j];
|
||||
*d[3] += a[j + 4] * b[j + 4];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // COPY from DPCT head files
|
||||
|
||||
#endif // GGML_SYCL_DPCT_HELPER_HPP
|
||||
|
||||
1179
ggml/src/ggml-sycl/fattn-common.hpp
Normal file
1179
ggml/src/ggml-sycl/fattn-common.hpp
Normal file
File diff suppressed because it is too large
Load Diff
55
ggml/src/ggml-sycl/fattn-tile.cpp
Normal file
55
ggml/src/ggml-sycl/fattn-tile.cpp
Normal file
@@ -0,0 +1,55 @@
|
||||
#include <sycl/sycl.hpp>
|
||||
#include <sycl/ext/oneapi/work_group_static.hpp>
|
||||
#include "dpct/helper.hpp"
|
||||
#include "common.hpp"
|
||||
#include "fattn-common.hpp"
|
||||
#include "fattn-tile.hpp"
|
||||
#include <cmath>
|
||||
#include <float.h>
|
||||
namespace syclex = sycl::ext::oneapi::experimental;
|
||||
|
||||
void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
switch (K->ne[0]) {
|
||||
case 40: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_sycl_flash_attn_ext_tile_case< 40, 40>(ctx, dst);
|
||||
} break;
|
||||
case 64: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_sycl_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
|
||||
} break;
|
||||
case 72: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_sycl_flash_attn_ext_tile_case< 72, 72>(ctx, dst);
|
||||
} break;
|
||||
case 80: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_sycl_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
|
||||
} break;
|
||||
case 96: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_sycl_flash_attn_ext_tile_case< 96, 96>(ctx, dst);
|
||||
} break;
|
||||
case 112: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_sycl_flash_attn_ext_tile_case<112, 112>(ctx, dst);
|
||||
} break;
|
||||
case 128: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_sycl_flash_attn_ext_tile_case<128, 128>(ctx, dst);
|
||||
} break;
|
||||
case 256: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst);
|
||||
} break;
|
||||
case 576: {
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("Unsupported head size");
|
||||
} break;
|
||||
}
|
||||
}
|
||||
1338
ggml/src/ggml-sycl/fattn-tile.hpp
Normal file
1338
ggml/src/ggml-sycl/fattn-tile.hpp
Normal file
File diff suppressed because it is too large
Load Diff
667
ggml/src/ggml-sycl/fattn-vec.hpp
Normal file
667
ggml/src/ggml-sycl/fattn-vec.hpp
Normal file
@@ -0,0 +1,667 @@
|
||||
#ifndef GGML_SYCL_FATTN_VEC_HPP
|
||||
#define GGML_SYCL_FATTN_VEC_HPP
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
#include <sycl/ext/oneapi/work_group_static.hpp>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "dpct/helper.hpp"
|
||||
#include "common.hpp"
|
||||
#include "ggml.h"
|
||||
#include "fattn-common.hpp"
|
||||
#include <cmath>
|
||||
#include <float.h>
|
||||
|
||||
namespace syclex = sycl::ext::oneapi::experimental;
|
||||
|
||||
static int ggml_sycl_fattn_vec_get_nthreads_host(const int cc) {
|
||||
return 128;
|
||||
GGML_UNUSED(cc);
|
||||
}
|
||||
|
||||
static constexpr int ggml_sycl_fattn_vec_get_nthreads_device() {
|
||||
return 128;
|
||||
}
|
||||
|
||||
// Currenlty llvm with the amdgcn target dose not support unrolling loops
|
||||
// that contain a break that can not be resolved at compile time.
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wpass-failed"
|
||||
#endif // __clang__
|
||||
|
||||
template <int D,
|
||||
int ncols,
|
||||
int type_K,
|
||||
int type_V,
|
||||
bool use_logit_softcap,
|
||||
int warp_size> // D == head size
|
||||
static void flash_attn_ext_vec(const char* __restrict__ Q,
|
||||
const char* __restrict__ K,
|
||||
const char* __restrict__ V,
|
||||
const char* __restrict__ mask,
|
||||
const char* __restrict__ sinks,
|
||||
const int* __restrict__ KV_max,
|
||||
float* __restrict__ dst,
|
||||
sycl::float2* __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00,
|
||||
const sycl::uint3 ne01,
|
||||
const int32_t ne02,
|
||||
const int32_t ne03,
|
||||
const int32_t nb01,
|
||||
const int32_t nb02,
|
||||
const int32_t nb03,
|
||||
const int32_t ne10,
|
||||
const int32_t ne11,
|
||||
const int32_t ne12,
|
||||
const int32_t ne13,
|
||||
const int32_t nb11,
|
||||
const int32_t nb12,
|
||||
const int64_t nb13,
|
||||
const int32_t nb21,
|
||||
const int32_t nb22,
|
||||
const int64_t nb23,
|
||||
const int32_t ne31,
|
||||
const int32_t ne32,
|
||||
const int32_t ne33,
|
||||
const int32_t nb31,
|
||||
const int32_t nb32,
|
||||
const int64_t nb33) {
|
||||
#ifdef SYCL_FLASH_ATTN
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
constexpr int nthreads_KQ_q = (D/4 < warp_size ? D/4 : warp_size);
|
||||
constexpr int nthreads_V_q = (D/4 < warp_size ? D/4 : warp_size);
|
||||
|
||||
constexpr int nthreads = ggml_sycl_fattn_vec_get_nthreads_device();
|
||||
constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
|
||||
constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
|
||||
|
||||
static_assert(warp_size % nthreads_KQ == 0, "bad nthreads_K");
|
||||
static_assert(warp_size % nthreads_V == 0, "bad nthreads_V");
|
||||
|
||||
constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
|
||||
constexpr int V_cols_per_iter = warp_size / nthreads_V;
|
||||
|
||||
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ, warp_size>();
|
||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
||||
#ifdef GGML_SYCL_F16
|
||||
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, sycl::half, V_rows_per_thread>();
|
||||
#else
|
||||
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
|
||||
#endif // GGML_SYCL_F16
|
||||
|
||||
const int ic0 = item_ct1.get_group(2) * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = item_ct1.get_group(0) / ne02;
|
||||
const int head = item_ct1.get_group(0) - sequence * ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
Q += nb03*sequence + nb02* head + nb01*ic0;
|
||||
K += nb13*sequence + nb12*(head / gqa_ratio);
|
||||
V += nb23*sequence + nb22*(head / gqa_ratio);
|
||||
|
||||
const sycl::half * maskh = (const sycl::half *) (mask + nb33 * (sequence % ne33) + nb31 * ic0);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
static_assert(D % (2*warp_size) == 0, "D not divisible by 2*warp_size == 64.");
|
||||
constexpr int nwarps = nthreads / warp_size;
|
||||
const int tid = warp_size * item_ct1.get_local_id(1) + item_ct1.get_local_id(2);
|
||||
__builtin_assume(tid < nthreads);
|
||||
|
||||
constexpr int ne_KQ = ncols*D;
|
||||
constexpr int ne_combine = nwarps*V_cols_per_iter*D;
|
||||
|
||||
constexpr size_t lsm_size1 = ncols * warp_size;
|
||||
constexpr size_t lsm_size2 = ncols * warp_size;
|
||||
#ifdef GGML_SYCL_F16
|
||||
sycl::half2 VKQ[ncols][(D / 2) / nthreads_V] = { { { 0.0f, 0.0f } } };
|
||||
constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine);
|
||||
constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2)*sizeof(float) + lsm_size3*sizeof(sycl::half);
|
||||
|
||||
syclex::work_group_static<char[local_share_mem_size]> lsm;
|
||||
|
||||
float *KQ_max_shared = (float *)&lsm;
|
||||
float *KQ_sum_shared = KQ_max_shared+lsm_size1;
|
||||
sycl::half* KQ = (sycl::half*)(KQ_sum_shared + lsm_size2);
|
||||
|
||||
|
||||
#else
|
||||
sycl::float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
|
||||
|
||||
constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine);
|
||||
constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2 + lsm_size3)*sizeof(float);
|
||||
|
||||
|
||||
syclex::work_group_static<char[local_share_mem_size]> lsm;
|
||||
float *KQ_max_shared = (float *)&lsm;
|
||||
float *KQ_sum_shared = KQ_max_shared+lsm_size1;
|
||||
float* KQ = KQ_sum_shared + lsm_size2;
|
||||
|
||||
#endif // GGML_SYCL_F16
|
||||
|
||||
float KQ_max[ncols];
|
||||
float KQ_sum[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ_max[j] = -FLT_MAX/2.0f;
|
||||
KQ_sum[j] = 0.0f;
|
||||
}
|
||||
|
||||
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
|
||||
#ifdef GGML_SYCL_F16
|
||||
sycl::half2 Q_reg[ncols][(D / 2) / nthreads_KQ] = {{{0.0f, 0.0f}}}; // Will be initialized completely.
|
||||
#else
|
||||
sycl::float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
|
||||
#endif // GGML_SYCL_F16
|
||||
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
||||
sycl::float2 Q_ds[ncols][1 > D / (sizeof(int) * nthreads_KQ) ? 1 : D / (sizeof(int) * nthreads_KQ)];
|
||||
if constexpr (Q_q8_1) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + item_ct1.get_local_id(1);
|
||||
|
||||
if (j0 + nwarps > ncols && j >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Reuse KQ as temporary storage for converting Q to q8_1:
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
sycl::float2 * tmp_q_ds = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int));
|
||||
|
||||
// Set memory to zero if out of bounds:
|
||||
if (ncols > 1 && ic0 + j >= int(ne01.z())) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += warp_size) {
|
||||
const int i = i0 + item_ct1.get_local_id(2);
|
||||
|
||||
if (i0 + warp_size <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
|
||||
tmp_q_i32[i] = 0;
|
||||
}
|
||||
}
|
||||
if (item_ct1.get_local_id(2) < D/QK8_1) {
|
||||
tmp_q_ds[item_ct1.get_local_id(2)] = sycl::float2(0.0f, 0.0f);
|
||||
}
|
||||
} else {
|
||||
const float * Q_f = (const float *) (Q + j*nb01);
|
||||
constexpr int nthreads_quantize = D/sizeof(int) < warp_size ? D/sizeof(int) : warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) {
|
||||
quantize_q8_1_to_shared<sycl::float2, nthreads_quantize, warp_size>
|
||||
(Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
sycl::float2 * tmp_q_ds = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int));
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) {
|
||||
const int i =
|
||||
i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ);
|
||||
|
||||
Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i];
|
||||
Q_ds[j][i0/nthreads_KQ] = tmp_q_ds[i/QI8_1];
|
||||
}
|
||||
}
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
} else {
|
||||
#ifdef GGML_SYCL_F16
|
||||
const sycl::half2 scale_h2 = sycl::half2(scale, scale);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j * nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
|
||||
const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) :
|
||||
item_ct1.get_local_id(2) % nthreads_KQ) *
|
||||
cpy_ne;
|
||||
|
||||
sycl::float2 tmp[cpy_ne] = {
|
||||
{ 0.0f, 0.0f }
|
||||
};
|
||||
if (ncols == 1 || ic0 + j < int(ne01.z())) {
|
||||
ggml_sycl_memcpy_1<cpy_nb>(tmp, &Q_j[i]);
|
||||
ggml_sycl_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne; ++i1) {
|
||||
Q_reg[j][i0 / nthreads_KQ + i1] = sycl::half2(tmp[i1].x(), tmp[i1].y());
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
|
||||
Q_reg[j][k] *= scale_h2;
|
||||
}
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
|
||||
const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ)*cpy_ne;
|
||||
if (ncols == 1 || ic0 + j < int(ne01.z())) {
|
||||
ggml_sycl_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]);
|
||||
ggml_sycl_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
|
||||
Q_reg[j][k].x() *= scale;
|
||||
Q_reg[j][k].y() *= scale;
|
||||
}
|
||||
}
|
||||
#endif // GGML_SYCL_F16
|
||||
}
|
||||
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
|
||||
K += item_ct1.get_group(1) * nthreads * nb11;
|
||||
V += item_ct1.get_group(1) * nthreads * nb21;
|
||||
maskh += item_ct1.get_group(1) * nthreads;
|
||||
for (int k_VKQ_0 = item_ct1.get_group(1) * nthreads; k_VKQ_0 < k_VKQ_max;
|
||||
k_VKQ_0 += item_ct1.get_group_range(1) * nthreads,
|
||||
// Increment pointers after each loop:
|
||||
K += item_ct1.get_group_range(1) * nthreads * nb11, V += item_ct1.get_group_range(1) * nthreads * nb21,
|
||||
maskh += item_ct1.get_group_range(1) * nthreads) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
float KQ_reg[ncols]={}; // KQ in registers.
|
||||
float KQ_max_new[ncols]={};
|
||||
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ_max_new[j] = KQ_max[j];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {
|
||||
const int i_KQ = item_ct1.get_local_id(1) * warp_size +
|
||||
(nthreads_KQ == warp_size ? 0 : (item_ct1.get_local_id(2) & ~(nthreads_KQ - 1))) + i_KQ_0;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
|
||||
sum = warp_reduce_sum<nthreads_KQ>(sum);
|
||||
|
||||
if (use_logit_softcap) {
|
||||
sum = logit_softcap * sycl::tanh(sum);
|
||||
}
|
||||
if (mask) {
|
||||
sum += slope * sycl::vec<sycl::half, 1>(maskh[j * ne11 + i_KQ])
|
||||
.convert<float, sycl::rounding_mode::automatic>()[0];
|
||||
}
|
||||
|
||||
KQ_max_new[j] = sycl::fmax((float) KQ_max_new[j], sum);
|
||||
|
||||
if (int(nthreads_KQ == warp_size ? item_ct1.get_local_id(2)
|
||||
: item_ct1.get_local_id(2) %
|
||||
nthreads_KQ) == i_KQ_0) {
|
||||
KQ_reg[j] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
#pragma unroll
|
||||
for (int offset = nthreads_KQ; offset < warp_size; offset <<= 1) {
|
||||
KQ_max_new[j] = sycl::fmax(
|
||||
(float)KQ_max_new[j],
|
||||
(float)dpct::permute_sub_group_by_xor(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(),
|
||||
KQ_max_new[j],
|
||||
offset,
|
||||
warp_size));
|
||||
}
|
||||
const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - KQ_max_new[j]));
|
||||
KQ_max[j] = KQ_max_new[j];
|
||||
|
||||
KQ_reg[j] = sycl::native::exp((float) (KQ_reg[j] - KQ_max[j]));
|
||||
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
|
||||
KQ[j*nthreads + tid] = KQ_reg[j];
|
||||
|
||||
#ifdef GGML_SYCL_F16
|
||||
const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale;
|
||||
VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale;
|
||||
}
|
||||
#endif // GGML_SYCL_F16
|
||||
}
|
||||
|
||||
sycl::group_barrier(sycl::ext::oneapi::this_work_item::get_sub_group());
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += V_cols_per_iter) {
|
||||
const int k = item_ct1.get_local_id(1) * warp_size + k0 +
|
||||
(nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V);
|
||||
|
||||
#ifdef GGML_SYCL_F16
|
||||
sycl::half2 KQ_k[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ_k[j] = sycl::half2(KQ[j * nthreads + k]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
|
||||
sycl::half2 tmp[V_rows_per_thread / 2];
|
||||
dequantize_V(V + k * nb21, tmp,
|
||||
2 * i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) :
|
||||
item_ct1.get_local_id(2) % nthreads_V) *
|
||||
V_rows_per_thread);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
float KQ_k[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ_k[j] = KQ[j*nthreads + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
|
||||
sycl::float2 tmp[V_rows_per_thread/2];
|
||||
dequantize_V(V + k*nb21, tmp,
|
||||
2*i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*V_rows_per_thread);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x() += tmp[i_VKQ_1].x()*KQ_k[j];
|
||||
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y() += tmp[i_VKQ_1].y()*KQ_k[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // GGML_SYCL_F16
|
||||
}
|
||||
}
|
||||
|
||||
if (sinks && item_ct1.get_group(1) == 0) {
|
||||
const float sink = ((const float *) sinks)[head];
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + item_ct1.get_local_id(1);
|
||||
|
||||
if (j0 + nwarps > ncols && j >= ncols) {
|
||||
break;
|
||||
}
|
||||
const float kqmax_new_j = sycl::fmax(sink, (float) KQ_max[j]);
|
||||
const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - kqmax_new_j));
|
||||
KQ_max[j] = kqmax_new_j;
|
||||
|
||||
KQ_sum[j] = KQ_sum[j] * KQ_max_scale +
|
||||
(item_ct1.get_local_id(2) == 0 ? sycl::native::exp((float) (sink - KQ_max[j])) : 0.0f);
|
||||
#ifdef GGML_SYCL_F16
|
||||
const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale;
|
||||
VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale;
|
||||
}
|
||||
#endif // GGML_SYCL_F16
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (item_ct1.get_local_id(1) == 0) {
|
||||
KQ_max_shared[j*warp_size+item_ct1.get_local_id(2)] = -FLT_MAX / 2.0f;
|
||||
KQ_sum_shared[j*warp_size+item_ct1.get_local_id(2)] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (item_ct1.get_local_id(2) == 0) {
|
||||
KQ_max_shared[j*warp_size+item_ct1.get_local_id(1)] = KQ_max[j];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
||||
if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z())) {
|
||||
break;
|
||||
}
|
||||
|
||||
float kqmax_new = KQ_max_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)];
|
||||
kqmax_new = warp_reduce_max<warp_size>(kqmax_new);
|
||||
const float kqmax_scale = sycl::native::exp((float) (KQ_max[j_VKQ] - kqmax_new));
|
||||
KQ_max[j_VKQ] = kqmax_new;
|
||||
|
||||
#ifdef GGML_SYCL_F16
|
||||
sycl::half2 * VKQ_tmp = (sycl::half2 *) KQ + item_ct1.get_local_id(1) * (V_cols_per_iter * D / 2) +
|
||||
(nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V) * (D / 2);
|
||||
|
||||
const sycl::half2 kqmax_scale_h2 = sycl::half2(kqmax_scale, kqmax_scale);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
|
||||
const int i_VKQ =
|
||||
i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V) *
|
||||
(V_rows_per_thread / 2);
|
||||
|
||||
ggml_sycl_memcpy_1<V_rows_per_thread * sizeof(sycl::half)>(VKQ_tmp + i_VKQ,
|
||||
&VKQ[j_VKQ][i_VKQ_0 / nthreads_V]);
|
||||
}
|
||||
#else
|
||||
sycl::float2 * VKQ_tmp = (sycl::float2 *) KQ + item_ct1.get_local_id(1)*(V_cols_per_iter*D/2)
|
||||
+ (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V)*(D/2);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
VKQ[j_VKQ][i_VKQ_0/nthreads_V].x() *= kqmax_scale;
|
||||
VKQ[j_VKQ][i_VKQ_0/nthreads_V].y() *= kqmax_scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
|
||||
const int i_VKQ = i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*(V_rows_per_thread/2);
|
||||
|
||||
ggml_sycl_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
|
||||
ggml_sycl_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
|
||||
}
|
||||
#endif // GGML_SYCL_F16
|
||||
|
||||
KQ_sum[j_VKQ] *= kqmax_scale;
|
||||
KQ_sum[j_VKQ] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ]);
|
||||
if (item_ct1.get_local_id(2) == 0) {
|
||||
KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(1)] = KQ_sum[j_VKQ];
|
||||
}
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
|
||||
if (nthreads <= D || tid < D) {
|
||||
KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)];
|
||||
KQ_sum[j_VKQ] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ]);
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += nthreads) {
|
||||
float dst_val = 0;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < nwarps; ++w) {
|
||||
#pragma unroll
|
||||
for (int v = 0; v < V_cols_per_iter; ++v) {
|
||||
dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);
|
||||
}
|
||||
}
|
||||
if (item_ct1.get_group_range(1) == 1) {
|
||||
dst_val /= KQ_sum[j_VKQ];
|
||||
}
|
||||
dst[(((sequence * int(ne01.z()) + ic0 + j_VKQ) * ne02 + head) * item_ct1.get_group_range(1) +
|
||||
item_ct1.get_group(1)) *
|
||||
D +
|
||||
i0 + tid] = dst_val;
|
||||
}
|
||||
}
|
||||
|
||||
if (j_VKQ < ncols-1) {
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (item_ct1.get_group_range(1) != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z()))) {
|
||||
dst_meta[((sequence * int(ne01.z()) + ic0 + tid) * ne02 + head) * item_ct1.get_group_range(1) +
|
||||
item_ct1.get_group(1)] = make_float2(KQ_max[tid], KQ_sum[tid]);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
|
||||
#endif // SYCL_FLASH_ATTN
|
||||
}
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif // __clang__
|
||||
|
||||
|
||||
template <int D, int cols_per_block, int type_K, int type_V, bool use_logit_softcap>
|
||||
void ggml_sycl_flash_attn_ext_vec_case_impl(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
const int warp_size = WARP_16_SIZE; //better performance than WARP_32_SIZE
|
||||
|
||||
const int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;
|
||||
|
||||
const int nthreads = ggml_sycl_fattn_vec_get_nthreads_host(cc);
|
||||
const int nwarps = nthreads / warp_size;
|
||||
|
||||
const bool need_f16_K = type_K == GGML_TYPE_F16;
|
||||
const bool need_f16_V = type_V == GGML_TYPE_F16;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
|
||||
launch_fattn<D, cols_per_block, 1,
|
||||
flash_attn_ext_vec<D, cols_per_block, type_K, type_V,
|
||||
use_logit_softcap, warp_size>, warp_size>(
|
||||
ctx, dst, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
}
|
||||
|
||||
template <int D, int type_K, int type_V>
|
||||
void ggml_sycl_flash_attn_ext_vec_case(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (Q->ne[1] == 1) {
|
||||
constexpr int cols_per_block = 1;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 2;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
#define DECL_FATTN_VEC_CASE(D, type_K, type_V) \
|
||||
template void ggml_sycl_flash_attn_ext_vec_case \
|
||||
<D, type_K, type_V>(ggml_backend_sycl_context & ctx, ggml_tensor * dst) \
|
||||
|
||||
#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
|
||||
|
||||
#endif // GGML_SYCL_FATTN_VEC_HPP
|
||||
225
ggml/src/ggml-sycl/fattn.cpp
Normal file
225
ggml/src/ggml-sycl/fattn.cpp
Normal file
@@ -0,0 +1,225 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
#include "dpct/helper.hpp"
|
||||
#include "common.hpp"
|
||||
#include "fattn-common.hpp"
|
||||
#include "fattn-tile.hpp"
|
||||
#include "fattn-vec.hpp"
|
||||
#include "fattn.hpp"
|
||||
|
||||
|
||||
#define FATTN_VEC_CASE(D, type_K, type_V) \
|
||||
{ \
|
||||
const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
|
||||
const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
|
||||
if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
|
||||
ggml_sycl_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
|
||||
FATTN_VEC_CASE( 64, type_K, type_V) \
|
||||
FATTN_VEC_CASE(128, type_K, type_V) \
|
||||
FATTN_VEC_CASE(256, type_K, type_V) \
|
||||
|
||||
static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * Q = dst->src[0];
|
||||
ggml_tensor * K = dst->src[1];
|
||||
ggml_tensor * V = dst->src[2];
|
||||
|
||||
#ifdef GGML_SYCL_FA_ALL_QUANTS
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
|
||||
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
#else
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
#endif // GGML_SYCL_FA_ALL_QUANTS
|
||||
|
||||
GGML_ABORT("Not match KV type in vec");
|
||||
}
|
||||
|
||||
// Best FlashAttention kernel for a specific GPU:
|
||||
enum best_fattn_kernel {
|
||||
BEST_FATTN_KERNEL_NONE = 0,
|
||||
BEST_FATTN_KERNEL_VEC = 100,
|
||||
BEST_FATTN_KERNEL_TILE = 200,
|
||||
};
|
||||
|
||||
static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
|
||||
GGML_UNUSED(device);
|
||||
#ifndef SYCL_FLASH_ATTN
|
||||
GGML_UNUSED(dst);
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
#endif// SYCL_FLASH_ATTN
|
||||
|
||||
if(!g_ggml_sycl_enable_flash_attention) return BEST_FATTN_KERNEL_NONE;
|
||||
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
for (const ggml_tensor * t : {Q, K, V, mask}) {
|
||||
if (t == nullptr || ggml_is_quantized(t->type)) {
|
||||
continue;
|
||||
}
|
||||
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
if (t->nb[i] % 16 != 0) {
|
||||
gqa_opt_applies = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch (K->ne[0]) {
|
||||
case 40:
|
||||
case 64:
|
||||
case 72:
|
||||
case 80:
|
||||
case 96:
|
||||
case 128:
|
||||
case 112:
|
||||
case 256:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 576:
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
|
||||
#ifndef GGML_SYCL_FA_ALL_QUANTS
|
||||
if (K->type != V->type) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
#endif // GGML_SYCL_FA_ALL_QUANTS
|
||||
|
||||
switch (K->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
#ifndef GGML_SYCL_FA_ALL_QUANTS
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
#endif // GGML_SYCL_FA_ALL_QUANTS
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
break;
|
||||
default:
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
|
||||
if (mask && mask->ne[2] != 1) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
|
||||
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
// Todo: Use the XMX kernel if possible:
|
||||
|
||||
// If there are no tensor cores available, use the generic tile kernel:
|
||||
if (can_use_vector_kernel) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (Q->ne[1] == 1) {
|
||||
if (!gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (Q->ne[1] <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
}
|
||||
return BEST_FATTN_KERNEL_TILE;
|
||||
}
|
||||
|
||||
void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_set_device(ctx.device);
|
||||
switch (ggml_sycl_get_best_fattn_kernel(ggml_sycl_get_device(), dst)) {
|
||||
case BEST_FATTN_KERNEL_NONE:
|
||||
GGML_ABORT("Not support Flash-Attention");
|
||||
case BEST_FATTN_KERNEL_TILE:
|
||||
ggml_sycl_flash_attn_ext_tile(ctx, dst);
|
||||
break;
|
||||
case BEST_FATTN_KERNEL_VEC:
|
||||
ggml_sycl_flash_attn_ext_vec(ctx, dst);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
|
||||
return ggml_sycl_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
22
ggml/src/ggml-sycl/fattn.hpp
Normal file
22
ggml/src/ggml-sycl/fattn.hpp
Normal file
@@ -0,0 +1,22 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
#ifndef GGML_SYCL_FATTN_HPP
|
||||
#define GGML_SYCL_FATTN_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_FATTN_HPP
|
||||
@@ -62,6 +62,8 @@ int g_ggml_sycl_disable_graph = 0;
|
||||
int g_ggml_sycl_disable_dnn = 0;
|
||||
int g_ggml_sycl_prioritize_dmmv = 0;
|
||||
int g_ggml_sycl_use_async_mem_op = 0;
|
||||
int g_ggml_sycl_enable_flash_attention = 1;
|
||||
|
||||
|
||||
static ggml_sycl_device_info ggml_sycl_init() {
|
||||
ggml_sycl_device_info info = {};
|
||||
@@ -94,11 +96,12 @@ static ggml_sycl_device_info ggml_sycl_init() {
|
||||
|
||||
info.devices[i].cc =
|
||||
100 * prop.get_major_version() + 10 * prop.get_minor_version();
|
||||
info.devices[i].nsm = prop.get_max_compute_units();
|
||||
info.devices[i].nsm = prop.get_max_compute_units() / 16; //16: Number of Xe Cores
|
||||
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
|
||||
info.devices[i].smpbo = prop.get_local_mem_size();
|
||||
|
||||
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
|
||||
info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units();
|
||||
|
||||
}
|
||||
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
@@ -211,7 +214,37 @@ static void ggml_check_sycl() try {
|
||||
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
|
||||
g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
|
||||
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
|
||||
|
||||
#ifdef SYCL_FLASH_ATTN
|
||||
g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1);
|
||||
#else
|
||||
g_ggml_sycl_enable_flash_attention = 0;
|
||||
#endif
|
||||
|
||||
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
|
||||
|
||||
GGML_LOG_INFO("Build with Macros:\n");
|
||||
#if defined(GGML_SYCL_FORCE_MMQ)
|
||||
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
|
||||
#else
|
||||
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
|
||||
#endif
|
||||
#if defined(GGML_SYCL_F16)
|
||||
GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
|
||||
#else
|
||||
GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
|
||||
#endif
|
||||
#if defined(GGML_SYCL_GRAPH)
|
||||
GGML_LOG_INFO(" GGML_SYCL_GRAPH: yes\n");
|
||||
#else
|
||||
GGML_LOG_INFO(" GGML_SYCL_GRAPH: no\n");
|
||||
#endif
|
||||
#if defined(GGML_SYCL_DNNL)
|
||||
GGML_LOG_INFO(" GGML_SYCL_DNNL: yes\n");
|
||||
#else
|
||||
GGML_LOG_INFO(" GGML_SYCL_DNNL: no\n");
|
||||
#endif
|
||||
|
||||
GGML_LOG_INFO("Running with Environment Variables:\n");
|
||||
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
|
||||
@@ -226,16 +259,12 @@ static void ggml_check_sycl() try {
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
|
||||
#endif
|
||||
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
|
||||
GGML_LOG_INFO("Build with Macros:\n");
|
||||
#if defined(GGML_SYCL_FORCE_MMQ)
|
||||
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
|
||||
|
||||
#ifdef SYCL_FLASH_ATTN
|
||||
GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention);
|
||||
#else
|
||||
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
|
||||
#endif
|
||||
#if defined(GGML_SYCL_F16)
|
||||
GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
|
||||
#else
|
||||
GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
|
||||
GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d disabled by compile flag\n",
|
||||
g_ggml_sycl_enable_flash_attention);
|
||||
#endif
|
||||
|
||||
/* NOT REMOVE, keep it for next optimize for XMX.
|
||||
@@ -3012,7 +3041,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
||||
|
||||
}
|
||||
#if GGML_SYCL_DNNL
|
||||
// oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl
|
||||
// oneDNN handles strided data and does not need overhead of ggml_get_to_fp16_nc_sycl
|
||||
const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
|
||||
src1_f16_alloc.alloc(ne_src1);
|
||||
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
|
||||
@@ -3021,7 +3050,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
||||
# else
|
||||
const int64_t ne_src1 = ggml_nelements(src1);
|
||||
src1_f16_alloc.alloc(ne_src1);
|
||||
const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
|
||||
const to_fp16_nc_sycl_t to_fp16_nc_sycl = ggml_get_to_fp16_nc_sycl(src1->type);
|
||||
GGML_ASSERT(to_fp16_nc_sycl != nullptr);
|
||||
to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
|
||||
#endif
|
||||
@@ -4158,6 +4187,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_ARANGE:
|
||||
ggml_sycl_arange(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
ggml_sycl_flash_attn_ext(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -4862,6 +4894,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ARANGE:
|
||||
return op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
return ggml_sycl_flash_attn_ext_supported(device, op);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -73,4 +73,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
|
||||
#define MUL_MAT_SRC1_COL_STRIDE 128
|
||||
|
||||
#define QK_WARP_SIZE 32
|
||||
#define WARP_32_SIZE 32
|
||||
#define WARP_16_SIZE 16
|
||||
|
||||
#endif // GGML_SYCL_PRESETS_HPP
|
||||
|
||||
@@ -102,7 +102,7 @@ static void soft_max_f32(const float * x,
|
||||
max_val = sycl::max(max_val, val);
|
||||
}
|
||||
// find the max value in the block
|
||||
max_val = warp_reduce_max(max_val);
|
||||
max_val = warp_reduce_max<WARP_SIZE>(max_val);
|
||||
|
||||
if (block_size > WARP_SIZE) {
|
||||
if (warp_id == 0) {
|
||||
@@ -116,7 +116,7 @@ static void soft_max_f32(const float * x,
|
||||
item_ct1.barrier();
|
||||
|
||||
max_val = buf_iw[lane_id];
|
||||
max_val = warp_reduce_max(max_val);
|
||||
max_val = warp_reduce_max<WARP_SIZE>(max_val);
|
||||
}
|
||||
float tmp = 0.0f; // partial sum
|
||||
|
||||
@@ -133,7 +133,7 @@ static void soft_max_f32(const float * x,
|
||||
vals[col] = val;
|
||||
}
|
||||
// find the sum of exps in the block
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
tmp = warp_reduce_sum<WARP_SIZE>(tmp);
|
||||
if (block_size > WARP_SIZE) {
|
||||
item_ct1.barrier();
|
||||
if (warp_id == 0) {
|
||||
@@ -153,7 +153,7 @@ static void soft_max_f32(const float * x,
|
||||
for (size_t i = 1; i < nreduce; i += 1) {
|
||||
tmp += buf_iw[lane_id + i * WARP_SIZE];
|
||||
}
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
tmp = warp_reduce_sum<WARP_SIZE>(tmp);
|
||||
}
|
||||
if (sinks) {
|
||||
tmp += sycl::native::exp(sinks[i02] - max_val);
|
||||
@@ -191,7 +191,7 @@ static void soft_max_back_f32(const float *grad, const float *dstf, float *dst,
|
||||
dgf_dot += dstf[col]*grad[col];
|
||||
}
|
||||
|
||||
dgf_dot = warp_reduce_sum(dgf_dot);
|
||||
dgf_dot = warp_reduce_sum<WARP_SIZE>(dgf_dot);
|
||||
|
||||
for (int col = tid; col < ncols; col += WARP_SIZE) {
|
||||
dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.hpp"
|
||||
|
||||
DECL_FATTN_TILE_CASE(112, 112);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.hpp"
|
||||
|
||||
DECL_FATTN_TILE_CASE(128, 128);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.hpp"
|
||||
|
||||
DECL_FATTN_TILE_CASE(256, 256);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.hpp"
|
||||
|
||||
DECL_FATTN_TILE_CASE(40, 40);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.hpp"
|
||||
|
||||
DECL_FATTN_TILE_CASE(576, 512);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.hpp"
|
||||
|
||||
DECL_FATTN_TILE_CASE(64, 64);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.hpp"
|
||||
|
||||
DECL_FATTN_TILE_CASE(72, 72);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.hpp"
|
||||
|
||||
DECL_FATTN_TILE_CASE(80, 80);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.hpp"
|
||||
|
||||
DECL_FATTN_TILE_CASE(96, 96);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.hpp"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
@@ -650,6 +650,19 @@ static __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int *v, const int *u,
|
||||
return d8_0*d8_1 * sumi;
|
||||
}
|
||||
|
||||
template <typename T, int vdr>
|
||||
static __dpct_inline__ T vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const T & d8_0, const T & d8_1) {
|
||||
int sumi = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vdr; ++i) {
|
||||
// SIMD dot product of quantized values
|
||||
sumi = ggml_sycl_dp4a(v[i], u[i], sumi);
|
||||
}
|
||||
|
||||
return d8_0*d8_1 * ((T) sumi);
|
||||
}
|
||||
|
||||
template <int vdr>
|
||||
static __dpct_inline__ float vec_dot_q8_1_q8_1_impl(const int *v, const int *u,
|
||||
const sycl::half2 &dm8,
|
||||
|
||||
@@ -744,6 +744,7 @@ struct vk_device_struct {
|
||||
|
||||
// [src/dst 0=fp32,1=fp16]
|
||||
vk_pipeline pipeline_exp[2];
|
||||
vk_pipeline pipeline_elu[2];
|
||||
vk_pipeline pipeline_gelu[2];
|
||||
vk_pipeline pipeline_gelu_erf[2];
|
||||
vk_pipeline pipeline_gelu_quick[2];
|
||||
@@ -762,6 +763,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_ceil[2];
|
||||
vk_pipeline pipeline_floor[2];
|
||||
vk_pipeline pipeline_trunc[2];
|
||||
vk_pipeline pipeline_sgn[2];
|
||||
|
||||
vk_pipeline pipeline_add1_f16_f16;
|
||||
vk_pipeline pipeline_add1_f16_f32;
|
||||
@@ -4373,6 +4375,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
CREATE_UNARY(elu)
|
||||
CREATE_UNARY(gelu)
|
||||
CREATE_UNARY(gelu_erf)
|
||||
CREATE_UNARY(gelu_quick)
|
||||
@@ -4391,6 +4394,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_UNARY(ceil)
|
||||
CREATE_UNARY(floor)
|
||||
CREATE_UNARY(trunc)
|
||||
CREATE_UNARY(sgn)
|
||||
#undef CREATE_UNARY
|
||||
|
||||
#define CREATE_UNARY_RTE(name) \
|
||||
@@ -9241,6 +9245,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
switch (ggml_get_unary_op(dst)) {
|
||||
case GGML_UNARY_OP_EXP:
|
||||
return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
|
||||
case GGML_UNARY_OP_ELU:
|
||||
return ctx->device->pipeline_elu[dst->type == GGML_TYPE_F16];
|
||||
case GGML_UNARY_OP_SILU:
|
||||
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
|
||||
case GGML_UNARY_OP_GELU:
|
||||
@@ -9277,6 +9283,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
|
||||
case GGML_UNARY_OP_SGN:
|
||||
return ctx->device->pipeline_sgn[dst->type == GGML_TYPE_F16];
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -12852,6 +12860,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
}
|
||||
|
||||
switch (ggml_get_unary_op(node)) {
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
@@ -12870,6 +12879,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
case GGML_UNARY_OP_SGN:
|
||||
ggml_vk_unary(ctx, compute_ctx, src0, node);
|
||||
break;
|
||||
case GGML_UNARY_OP_XIELU:
|
||||
@@ -13248,6 +13258,10 @@ static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, g
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
||||
vk_buffer buf = buf_ctx->dev_buffer;
|
||||
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t val32 = (uint32_t)value * 0x01010101;
|
||||
ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
|
||||
}
|
||||
@@ -13257,6 +13271,10 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
||||
vk_buffer buf = buf_ctx->dev_buffer;
|
||||
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
|
||||
}
|
||||
|
||||
@@ -13264,12 +13282,20 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons
|
||||
VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
||||
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
vk_buffer buf = buf_ctx->dev_buffer;
|
||||
|
||||
ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
|
||||
}
|
||||
|
||||
static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
if (ggml_nbytes(src) == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ggml_backend_buffer_is_vk(src->buffer)) {
|
||||
ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
@@ -13459,6 +13485,10 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
|
||||
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
|
||||
|
||||
vk_context cpy_ctx;
|
||||
@@ -13502,6 +13532,10 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
|
||||
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
|
||||
|
||||
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
||||
@@ -13528,9 +13562,14 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
|
||||
}
|
||||
|
||||
static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
|
||||
VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async(" << src << " -> " << dst << ", size=" << ggml_nbytes(src) << ")");
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context;
|
||||
|
||||
// Skip zero-size tensors
|
||||
if (ggml_nbytes(src) == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (dst->buffer->buft != ggml_backend_vk_get_default_buffer_type(backend_dst)) {
|
||||
return false;
|
||||
}
|
||||
@@ -14951,6 +14990,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
@@ -14969,6 +15009,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
case GGML_UNARY_OP_SGN:
|
||||
return ggml_is_contiguous(op->src[0]) &&
|
||||
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||
@@ -16074,6 +16115,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
case GGML_UNARY_OP_EXP:
|
||||
tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
|
||||
break;
|
||||
case GGML_UNARY_OP_ELU:
|
||||
tensor_clone = ggml_elu(ggml_ctx, src_clone[0]);
|
||||
break;
|
||||
case GGML_UNARY_OP_SILU:
|
||||
tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
|
||||
break;
|
||||
@@ -16132,6 +16176,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
|
||||
break;
|
||||
case GGML_UNARY_OP_SGN:
|
||||
tensor_clone = ggml_sgn(ggml_ctx, src_clone[0]);
|
||||
break;
|
||||
default:
|
||||
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
||||
GGML_ABORT("fatal error");
|
||||
|
||||
27
ggml/src/ggml-vulkan/vulkan-shaders/elu.comp
Normal file
27
ggml/src/ggml-vulkan/vulkan-shaders/elu.comp
Normal file
@@ -0,0 +1,27 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
|
||||
float x = float(data_a[i]);
|
||||
|
||||
if (x < 0.0f) {
|
||||
x = exp(x) - 1;
|
||||
}
|
||||
|
||||
data_d[i] = D_TYPE(x);
|
||||
}
|
||||
@@ -377,6 +377,7 @@ void main() {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
barrier();
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
const uint row_i = dc + cm_col * TN + col + store_c;
|
||||
if (row_i >= _ne1) break;
|
||||
@@ -387,6 +388,7 @@ void main() {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
}
|
||||
#else
|
||||
@@ -404,18 +406,22 @@ void main() {
|
||||
// Full coopMat is within bounds, but stride_d is not aligned
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
|
||||
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
|
||||
// Partial coopMat is within bounds
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
|
||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
}
|
||||
controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
21
ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp
Normal file
21
ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp
Normal file
@@ -0,0 +1,21 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
|
||||
data_d[i] = D_TYPE(sign(float(data_a[i])));
|
||||
}
|
||||
@@ -867,8 +867,12 @@ void process_shaders() {
|
||||
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("elu_f16", "elu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("elu_f32", "elu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("sgn_f16", "sgn.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("sgn_f32", "sgn.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-opt.h"
|
||||
#include "gguf.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
@@ -440,19 +441,30 @@ extern "C" {
|
||||
|
||||
LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
|
||||
|
||||
typedef void (*llama_model_set_tensor_data_t)(struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
// Create a new model from GGUF metadata as well as a function to set the tensor data
|
||||
// - tensors are created as GGML_TYPE_F32 by default,
|
||||
// override by adding a tensor with the same name but a different name to the context
|
||||
LLAMA_API struct llama_model * llama_model_init_from_user(
|
||||
struct gguf_context * metadata,
|
||||
llama_model_set_tensor_data_t set_tensor_data, // function to initialize tensor data with
|
||||
void * set_tensor_data_ud, // userdata for function
|
||||
struct llama_model_params params);
|
||||
|
||||
DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file(
|
||||
const char * path_model,
|
||||
struct llama_model_params params),
|
||||
"use llama_model_load_from_file instead");
|
||||
|
||||
// Load the model from a file
|
||||
// Load a model from a file
|
||||
// If the file is split into multiple parts, the file name must follow this pattern: <name>-%05d-of-%05d.gguf
|
||||
// If the split file name does not follow this pattern, use llama_model_load_from_splits
|
||||
LLAMA_API struct llama_model * llama_model_load_from_file(
|
||||
const char * path_model,
|
||||
struct llama_model_params params);
|
||||
|
||||
// Load the model from multiple splits (support custom naming scheme)
|
||||
// Load a model from multiple splits (support custom naming scheme)
|
||||
// The paths must be in the correct order
|
||||
LLAMA_API struct llama_model * llama_model_load_from_splits(
|
||||
const char ** paths,
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
{%- set messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
{%- if tools -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "You can use the following tools: <|tool_list_start|>[" -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: <|tool_list_start|>[" -%}
|
||||
{%- for tool in tools -%}
|
||||
{%- if tool is not string -%}
|
||||
{%- set tool = tool | tojson -%}
|
||||
@@ -17,7 +17,6 @@
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%}
|
||||
{{- '**IMPORTANT**: The syntax for calling the tools is: <|tool_call_start|>JSON tool call goes here<|tool_call_end|>. Please only call tools in the specified manner.' -}}
|
||||
{%- endif -%}
|
||||
{%- if ns.system_prompt -%}
|
||||
{{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
|
||||
@@ -30,18 +29,9 @@
|
||||
{%- endif -%}
|
||||
{%- if message["role"] == "tool" -%}
|
||||
{%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%}
|
||||
{%- elif message["role"] == "assistant" -%}
|
||||
{%- if message.tool_calls %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '\n<|tool_call_start|>\n{"name": "' + tool_call.name + '", "arguments": ' + (tool_call.arguments if tool_call.arguments is string else tool_call.arguments | tojson) + '}\n<|tool_call_end|>\n' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- endif -%}
|
||||
{{- content + "<|im_end|>\n" -}}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- "<|im_start|>assistant\n" -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
@@ -1,37 +0,0 @@
|
||||
{{- bos_token -}}
|
||||
{%- set system_prompt = "" -%}
|
||||
{%- set ns = namespace(system_prompt="") -%}
|
||||
{%- if messages[0]["role"] == "system" -%}
|
||||
{%- set ns.system_prompt = messages[0]["content"] -%}
|
||||
{%- set messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
{%- if tools -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: <|tool_list_start|>[" -%}
|
||||
{%- for tool in tools -%}
|
||||
{%- if tool is not string -%}
|
||||
{%- set tool = tool | tojson -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + tool -%}
|
||||
{%- if not loop.last -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + ", " -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%}
|
||||
{%- endif -%}
|
||||
{%- if ns.system_prompt -%}
|
||||
{{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
|
||||
{%- endif -%}
|
||||
{%- for message in messages -%}
|
||||
{{- "<|im_start|>" + message["role"] + "\n" -}}
|
||||
{%- set content = message["content"] -%}
|
||||
{%- if content is not string -%}
|
||||
{%- set content = content | tojson -%}
|
||||
{%- endif -%}
|
||||
{%- if message["role"] == "tool" -%}
|
||||
{%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%}
|
||||
{%- endif -%}
|
||||
{{- content + "<|im_end|>\n" -}}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- "<|im_start|>assistant\n" -}}
|
||||
{%- endif -%}
|
||||
18
scripts/git-bisect-run.sh
Executable file
18
scripts/git-bisect-run.sh
Executable file
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
cmake_args=()
|
||||
llama_results_args=()
|
||||
|
||||
for arg in "${@}"; do
|
||||
if [[ "$arg" == -D* ]]; then
|
||||
cmake_args+=("$arg")
|
||||
else
|
||||
llama_results_args+=("$arg")
|
||||
fi
|
||||
done
|
||||
|
||||
dir="build-bisect"
|
||||
rm -rf ${dir} > /dev/null
|
||||
cmake -B ${dir} -S . ${cmake_args} > /dev/null
|
||||
cmake --build ${dir} -t llama-results -j $(nproc) > /dev/null
|
||||
${dir}/bin/llama-results "${llama_results_args[@]}"
|
||||
19
scripts/git-bisect.sh
Executable file
19
scripts/git-bisect.sh
Executable file
@@ -0,0 +1,19 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
if [ $# -lt 2 ]; then
|
||||
echo "usage: ./scripts/git-bisect.sh <commit_bad> <commit_good> [additional arguments]"
|
||||
echo " additional arguments: passed to CMake if they start with \"-D\", to llama-results otherwise"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
commit_bad=$1
|
||||
commit_good=$2
|
||||
script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
git checkout ${commit_good}
|
||||
${script_dir}/git-bisect-run.sh --output results.gguf "${@:3}"
|
||||
git bisect start ${commit_bad} ${commit_good}
|
||||
git bisect run ${script_dir}/git-bisect-run.sh --output results.gguf --check "${@:3}"
|
||||
git bisect reset
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_CLIP, "clip" }, // dummy, only used by llama-quantize
|
||||
@@ -2786,6 +2787,15 @@ std::string LLM_TN_IMPL::str() const {
|
||||
return name;
|
||||
}
|
||||
|
||||
std::vector<llm_arch> llm_arch_all() {
|
||||
std::vector<llm_arch> ret;
|
||||
ret.reserve(LLM_ARCH_NAMES.size());
|
||||
for (const auto & [arch, _] : LLM_ARCH_NAMES) {
|
||||
ret.push_back(arch);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
const char * llm_arch_name(llm_arch arch) {
|
||||
auto it = LLM_ARCH_NAMES.find(arch);
|
||||
if (it == LLM_ARCH_NAMES.end()) {
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
//
|
||||
// gguf constants (sync with gguf.py)
|
||||
@@ -608,6 +609,8 @@ struct llm_tensor_info {
|
||||
ggml_op op;
|
||||
};
|
||||
|
||||
std::vector<llm_arch> llm_arch_all();
|
||||
|
||||
const char * llm_arch_name(llm_arch arch);
|
||||
|
||||
llm_arch llm_arch_from_string(const std::string & name);
|
||||
|
||||
@@ -1158,6 +1158,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
|
||||
{
|
||||
//const auto t_start_us = ggml_time_us();
|
||||
|
||||
// FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated
|
||||
res->set_inputs(&ubatch);
|
||||
|
||||
//LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
|
||||
|
||||
@@ -509,6 +509,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
||||
float * data = (float *) cross_kq_mask->data;
|
||||
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
float f = -INFINITY;
|
||||
|
||||
@@ -1161,7 +1162,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
int64_t n_expert_used,
|
||||
llm_ffn_op_type type_op,
|
||||
bool norm_w,
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llama_expert_gating_func_type gating_op,
|
||||
int il,
|
||||
@@ -1178,7 +1178,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
n_expert_used,
|
||||
type_op,
|
||||
norm_w,
|
||||
scale_w,
|
||||
w_scale,
|
||||
gating_op,
|
||||
il,
|
||||
@@ -1202,7 +1201,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
int64_t n_expert_used,
|
||||
llm_ffn_op_type type_op,
|
||||
bool norm_w,
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llama_expert_gating_func_type gating_op,
|
||||
int il,
|
||||
@@ -1330,7 +1328,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
|
||||
weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
|
||||
}
|
||||
if (scale_w) {
|
||||
if (w_scale != 0.0f && w_scale != 1.0f) {
|
||||
weights = ggml_scale(ctx0, weights, w_scale);
|
||||
cb(weights, "ffn_moe_weights_scaled", il);
|
||||
}
|
||||
@@ -1607,6 +1605,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
|
||||
// this need to be 1x1xN for broadcasting
|
||||
cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
|
||||
ggml_set_input(cur);
|
||||
ggml_set_name(cur, "attn_scale");
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
|
||||
|
||||
@@ -810,7 +810,6 @@ struct llm_graph_context {
|
||||
int64_t n_expert_used,
|
||||
llm_ffn_op_type type_op,
|
||||
bool norm_w,
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llama_expert_gating_func_type gating_op,
|
||||
int il,
|
||||
@@ -832,7 +831,6 @@ struct llm_graph_context {
|
||||
int64_t n_expert_used,
|
||||
llm_ffn_op_type type_op,
|
||||
bool norm_w,
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llama_expert_gating_func_type gating_op,
|
||||
int il,
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
#include "llama-model-loader.h"
|
||||
|
||||
#include "ggml-alloc.h"
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
#include "llama-hparams.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cinttypes>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <future>
|
||||
#include <regex>
|
||||
|
||||
static const size_t kiB = 1024;
|
||||
static const size_t MiB = 1024*kiB;
|
||||
@@ -263,7 +268,7 @@ namespace GGUFMeta {
|
||||
template<typename T>
|
||||
typename std::enable_if<std::is_integral<T>::value, bool>::type
|
||||
llama_model_loader::get_arr_n(const std::string & key, T & result, bool required) {
|
||||
const int kid = gguf_find_key(meta.get(), key.c_str());
|
||||
const int kid = gguf_find_key(metadata, key.c_str());
|
||||
|
||||
if (kid < 0) {
|
||||
if (required) {
|
||||
@@ -273,7 +278,7 @@ namespace GGUFMeta {
|
||||
}
|
||||
|
||||
struct GGUFMeta::ArrayInfo arr_info =
|
||||
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
|
||||
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(metadata, kid);
|
||||
|
||||
|
||||
result = arr_info.length;
|
||||
@@ -290,7 +295,7 @@ namespace GGUFMeta {
|
||||
|
||||
template<typename T>
|
||||
bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
|
||||
const gguf_context * ctx = meta.get();
|
||||
const gguf_context * ctx = metadata;
|
||||
const int kid = gguf_find_key(ctx, key.c_str());
|
||||
|
||||
if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
|
||||
@@ -331,7 +336,7 @@ namespace GGUFMeta {
|
||||
|
||||
template<typename T, size_t N_MAX>
|
||||
bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
|
||||
const gguf_context * ctx = meta.get();
|
||||
const gguf_context * ctx = metadata;
|
||||
const int kid = gguf_find_key(ctx, key.c_str());
|
||||
|
||||
if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
|
||||
@@ -393,7 +398,7 @@ namespace GGUFMeta {
|
||||
const struct llama_model_kv_override * override =
|
||||
it != kv_overrides.end() ? &it->second : nullptr;
|
||||
|
||||
const bool found = GGUFMeta::GKV<T>::set(meta.get(), key, result, override);
|
||||
const bool found = GGUFMeta::GKV<T>::set(metadata, key, result, override);
|
||||
|
||||
if (required && !found) {
|
||||
throw std::runtime_error(format("key not found in model: %s", key.c_str()));
|
||||
@@ -427,7 +432,7 @@ namespace GGUFMeta {
|
||||
// get array of n <= N_MAX elements, or a single element repeated n times
|
||||
template<typename T, size_t N_MAX>
|
||||
bool llama_model_loader::get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, bool required) {
|
||||
const int kid = gguf_find_key(meta.get(), key.c_str());
|
||||
const int kid = gguf_find_key(metadata, key.c_str());
|
||||
|
||||
if (kid < 0) {
|
||||
if (required) {
|
||||
@@ -440,9 +445,9 @@ namespace GGUFMeta {
|
||||
throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str()));
|
||||
}
|
||||
|
||||
if (gguf_get_kv_type(meta.get(), kid) == GGUF_TYPE_ARRAY) {
|
||||
if (gguf_get_kv_type(metadata, kid) == GGUF_TYPE_ARRAY) {
|
||||
struct GGUFMeta::ArrayInfo arr_info =
|
||||
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
|
||||
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(metadata, kid);
|
||||
|
||||
if (n != arr_info.length) {
|
||||
throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length));
|
||||
@@ -473,7 +478,7 @@ namespace GGUFMeta {
|
||||
bool llama_model_loader::get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required) {
|
||||
const std::string key = llm_kv(kid);
|
||||
|
||||
const int id = gguf_find_key(meta.get(), key.c_str());
|
||||
const int id = gguf_find_key(metadata, key.c_str());
|
||||
|
||||
if (id < 0) {
|
||||
if (required) {
|
||||
@@ -483,7 +488,7 @@ namespace GGUFMeta {
|
||||
}
|
||||
|
||||
// throw and error if type is an array
|
||||
if (gguf_get_kv_type(meta.get(), id) == GGUF_TYPE_ARRAY) {
|
||||
if (gguf_get_kv_type(metadata, id) == GGUF_TYPE_ARRAY) {
|
||||
if (required) {
|
||||
throw std::runtime_error(format("expected scalar, found array for key: %s", key.c_str()));
|
||||
}
|
||||
@@ -500,6 +505,9 @@ namespace GGUFMeta {
|
||||
|
||||
|
||||
llama_model_loader::llama_model_loader(
|
||||
struct gguf_context * meta,
|
||||
llama_model_set_tensor_data_t set_tensor_data,
|
||||
void * set_tensor_data_ud,
|
||||
const std::string & fname,
|
||||
std::vector<std::string> & splits,
|
||||
bool use_mmap,
|
||||
@@ -507,7 +515,8 @@ llama_model_loader::llama_model_loader(
|
||||
bool check_tensors,
|
||||
bool no_alloc,
|
||||
const llama_model_kv_override * param_overrides_p,
|
||||
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
|
||||
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p)
|
||||
: metadata(meta), set_tensor_data(set_tensor_data), set_tensor_data_ud(set_tensor_data_ud) {
|
||||
int trace = 0;
|
||||
if (getenv("LLAMA_TRACE")) {
|
||||
trace = atoi(getenv("LLAMA_TRACE"));
|
||||
@@ -521,136 +530,142 @@ llama_model_loader::llama_model_loader(
|
||||
|
||||
tensor_buft_overrides = param_tensor_buft_overrides_p;
|
||||
|
||||
// Load the main GGUF
|
||||
struct ggml_context * ctx = NULL;
|
||||
struct gguf_init_params params = {
|
||||
/*.no_alloc = */ true,
|
||||
/*.ctx = */ &ctx,
|
||||
};
|
||||
if (!fname.empty()) {
|
||||
// Load the main GGUF
|
||||
struct ggml_context * ctx = NULL;
|
||||
struct gguf_init_params params = {
|
||||
/*.no_alloc = */ true,
|
||||
/*.ctx = */ &ctx,
|
||||
};
|
||||
|
||||
meta.reset(gguf_init_from_file(fname.c_str(), params));
|
||||
if (!meta) {
|
||||
throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str()));
|
||||
}
|
||||
|
||||
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
|
||||
llm_kv = LLM_KV(llm_arch_from_string(arch_name));
|
||||
|
||||
files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io));
|
||||
contexts.emplace_back(ctx);
|
||||
|
||||
if (use_mmap && use_direct_io) {
|
||||
if (files.back()->has_direct_io()) {
|
||||
LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__);
|
||||
use_mmap = false;
|
||||
} else {
|
||||
LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__);
|
||||
use_direct_io = false;
|
||||
|
||||
// reopen file using std::fopen for mmap
|
||||
files.pop_back();
|
||||
files.emplace_back(new llama_file(fname.c_str(), "rb", false));
|
||||
}
|
||||
}
|
||||
|
||||
// Save tensors data offset of the main file.
|
||||
// For subsidiary files, `meta` tensor data offset must not be used,
|
||||
// so we build a unified tensors index for weights.
|
||||
for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
|
||||
std::string tensor_name = std::string(cur->name);
|
||||
// make sure there is no duplicated tensor names
|
||||
if (weights_map.find(tensor_name) != weights_map.end()) {
|
||||
throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur)));
|
||||
}
|
||||
n_elements += ggml_nelements(cur);
|
||||
n_bytes += ggml_nbytes(cur);
|
||||
weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, meta.get(), cur));
|
||||
}
|
||||
uint16_t n_split = 0;
|
||||
get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false);
|
||||
|
||||
// Load additional GGML contexts
|
||||
if (n_split > 1) {
|
||||
// make sure the main file is loaded first
|
||||
uint16_t idx = 0;
|
||||
const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO);
|
||||
get_key(kv_split_no, idx);
|
||||
if (idx != 0) {
|
||||
throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str()));
|
||||
metadata_ptr.reset(gguf_init_from_file(fname.c_str(), params));
|
||||
metadata = metadata_ptr.get();
|
||||
if (metadata == nullptr) {
|
||||
throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str()));
|
||||
}
|
||||
|
||||
// generate list of splits if needed
|
||||
if (splits.empty()) {
|
||||
splits = llama_get_list_splits(fname, idx, n_split);
|
||||
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
|
||||
llm_kv = LLM_KV(llm_arch_from_string(arch_name));
|
||||
|
||||
files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io));
|
||||
contexts.emplace_back(ctx);
|
||||
|
||||
if (use_mmap && use_direct_io) {
|
||||
if (files.back()->has_direct_io()) {
|
||||
LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__);
|
||||
use_mmap = false;
|
||||
} else {
|
||||
LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__);
|
||||
use_direct_io = false;
|
||||
|
||||
// reopen file using std::fopen for mmap
|
||||
files.pop_back();
|
||||
files.emplace_back(new llama_file(fname.c_str(), "rb", false));
|
||||
}
|
||||
}
|
||||
|
||||
// in case user give a custom list of splits, check if it matches the expected number
|
||||
if (n_split != (uint16_t)splits.size()) {
|
||||
throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split));
|
||||
// Save tensors data offset of the main file.
|
||||
// For subsidiary files, `meta` tensor data offset must not be used,
|
||||
// so we build a unified tensors index for weights.
|
||||
for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
|
||||
std::string tensor_name = std::string(cur->name);
|
||||
// make sure there is no duplicated tensor names
|
||||
if (weights_map.find(tensor_name) != weights_map.end()) {
|
||||
throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur)));
|
||||
}
|
||||
n_elements += ggml_nelements(cur);
|
||||
n_bytes += ggml_nbytes(cur);
|
||||
weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, metadata, cur));
|
||||
}
|
||||
uint16_t n_split = 0;
|
||||
get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false);
|
||||
|
||||
if (trace > 0) {
|
||||
LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
|
||||
}
|
||||
|
||||
// load other splits
|
||||
for (idx = 1; idx < n_split; idx++) {
|
||||
const char * fname_split = splits[idx].c_str();
|
||||
|
||||
struct gguf_init_params split_params = {
|
||||
/*.no_alloc = */ true,
|
||||
/*.ctx = */ &ctx,
|
||||
};
|
||||
gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
|
||||
if (!ctx_gguf) {
|
||||
throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split));
|
||||
// Load additional GGML contexts
|
||||
if (n_split > 1) {
|
||||
// make sure the main file is loaded first
|
||||
uint16_t idx = 0;
|
||||
const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO);
|
||||
get_key(kv_split_no, idx);
|
||||
if (idx != 0) {
|
||||
throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str()));
|
||||
}
|
||||
|
||||
// check idx
|
||||
// generate list of splits if needed
|
||||
if (splits.empty()) {
|
||||
splits = llama_get_list_splits(fname, idx, n_split);
|
||||
}
|
||||
|
||||
// in case user give a custom list of splits, check if it matches the expected number
|
||||
if (n_split != (uint16_t)splits.size()) {
|
||||
throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split));
|
||||
}
|
||||
|
||||
if (trace > 0) {
|
||||
LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
|
||||
}
|
||||
|
||||
// load other splits
|
||||
for (idx = 1; idx < n_split; idx++) {
|
||||
const char * fname_split = splits[idx].c_str();
|
||||
|
||||
struct gguf_init_params split_params = {
|
||||
/*.no_alloc = */ true,
|
||||
/*.ctx = */ &ctx,
|
||||
};
|
||||
gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
|
||||
if (!ctx_gguf) {
|
||||
throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split));
|
||||
}
|
||||
|
||||
// check idx
|
||||
{
|
||||
const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str());
|
||||
if (kid < 0) {
|
||||
throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split));
|
||||
}
|
||||
int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid);
|
||||
if (idx_gguf != idx) {
|
||||
throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx));
|
||||
}
|
||||
}
|
||||
|
||||
files.emplace_back(new llama_file(fname_split, "rb", use_direct_io));
|
||||
contexts.emplace_back(ctx);
|
||||
|
||||
// Save tensors data offset info of the shard.
|
||||
for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
|
||||
std::string tensor_name = std::string(cur->name);
|
||||
// make sure there is no duplicated tensor names
|
||||
if (weights_map.find(tensor_name) != weights_map.end()) {
|
||||
throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur)));
|
||||
}
|
||||
n_elements += ggml_nelements(cur);
|
||||
n_bytes += ggml_nbytes(cur);
|
||||
weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur));
|
||||
}
|
||||
}
|
||||
|
||||
get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors);
|
||||
|
||||
// sanity check
|
||||
{
|
||||
const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str());
|
||||
if (kid < 0) {
|
||||
throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split));
|
||||
}
|
||||
int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid);
|
||||
if (idx_gguf != idx) {
|
||||
throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx));
|
||||
const int n_tensors_loaded = (int) weights_map.size();
|
||||
if (n_tensors != n_tensors_loaded) {
|
||||
throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded));
|
||||
}
|
||||
}
|
||||
|
||||
files.emplace_back(new llama_file(fname_split, "rb", use_direct_io));
|
||||
contexts.emplace_back(ctx);
|
||||
|
||||
// Save tensors data offset info of the shard.
|
||||
for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
|
||||
std::string tensor_name = std::string(cur->name);
|
||||
// make sure there is no duplicated tensor names
|
||||
if (weights_map.find(tensor_name) != weights_map.end()) {
|
||||
throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur)));
|
||||
}
|
||||
n_elements += ggml_nelements(cur);
|
||||
n_bytes += ggml_nbytes(cur);
|
||||
weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur));
|
||||
}
|
||||
LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1);
|
||||
}
|
||||
|
||||
get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors);
|
||||
|
||||
// sanity check
|
||||
{
|
||||
const int n_tensors_loaded = (int) weights_map.size();
|
||||
if (n_tensors != n_tensors_loaded) {
|
||||
throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded));
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1);
|
||||
} else {
|
||||
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
|
||||
llm_kv = LLM_KV(llm_arch_from_string(arch_name));
|
||||
}
|
||||
|
||||
n_kv = gguf_get_n_kv(meta.get());
|
||||
n_kv = gguf_get_n_kv(metadata);
|
||||
n_tensors = weights_map.size();
|
||||
|
||||
fver = (enum llama_fver) gguf_get_version(meta.get());
|
||||
fver = (enum llama_fver) gguf_get_version(metadata);
|
||||
|
||||
LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n",
|
||||
__func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver));
|
||||
@@ -729,14 +744,14 @@ llama_model_loader::llama_model_loader(
|
||||
LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__);
|
||||
|
||||
for (int i = 0; i < n_kv; i++) {
|
||||
const char * name = gguf_get_key(meta.get(), i);
|
||||
const enum gguf_type type = gguf_get_kv_type(meta.get(), i);
|
||||
const char * name = gguf_get_key(metadata, i);
|
||||
const enum gguf_type type = gguf_get_kv_type(metadata, i);
|
||||
const std::string type_name =
|
||||
type == GGUF_TYPE_ARRAY
|
||||
? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i))
|
||||
? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(metadata, i)), gguf_get_arr_n(metadata, i))
|
||||
: gguf_type_name(type);
|
||||
|
||||
std::string value = gguf_kv_to_str(meta.get(), i);
|
||||
std::string value = gguf_kv_to_str(metadata, i);
|
||||
const size_t MAX_VALUE_LEN = 40;
|
||||
if (value.size() > MAX_VALUE_LEN) {
|
||||
value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str());
|
||||
@@ -838,15 +853,382 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri
|
||||
return cur;
|
||||
}
|
||||
|
||||
struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags) {
|
||||
LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, name.c_str());
|
||||
const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
|
||||
// checks if the weight tensor can be used with the specified buffer type and device
|
||||
static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
|
||||
GGML_ASSERT(w != nullptr);
|
||||
|
||||
if (op == GGML_OP_NONE) {
|
||||
return true;
|
||||
}
|
||||
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ ggml_tensor_overhead()*8,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
||||
if (!ctx_ptr) {
|
||||
throw std::runtime_error(format("failed to create ggml context"));
|
||||
}
|
||||
ggml_context * ctx = ctx_ptr.get();
|
||||
|
||||
ggml_tensor * op_tensor = nullptr;
|
||||
|
||||
switch (op) {
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);
|
||||
op_tensor = ggml_get_rows(ctx, w, b);
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]);
|
||||
op_tensor = ggml_mul_mat(ctx, w, b);
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
const int n_expert_used = hparams.n_expert_used;
|
||||
GGML_ASSERT(n_expert_used > 0);
|
||||
ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512);
|
||||
ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512);
|
||||
op_tensor = ggml_mul_mat_id(ctx, w, b, ids);
|
||||
} break;
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
|
||||
op_tensor = ggml_add(ctx, a, w);
|
||||
} break;
|
||||
case GGML_OP_ADD_ID:
|
||||
{
|
||||
const int n_expert_used = hparams.n_expert_used;
|
||||
GGML_ASSERT(n_expert_used > 0);
|
||||
ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512);
|
||||
ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512);
|
||||
op_tensor = ggml_add_id(ctx, a, w, c);
|
||||
} break;
|
||||
case GGML_OP_MUL:
|
||||
{
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
|
||||
op_tensor = ggml_mul(ctx, a, w);
|
||||
} break;
|
||||
case GGML_OP_DIV:
|
||||
{
|
||||
ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]);
|
||||
op_tensor = ggml_div(ctx, a, w);
|
||||
} break;
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
const int n_embd_head = hparams.n_embd_head_v;
|
||||
const int n_head = hparams.n_head();
|
||||
ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512);
|
||||
ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);
|
||||
op_tensor = ggml_rope_ext(
|
||||
ctx, a, b, w,
|
||||
0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0
|
||||
);
|
||||
|
||||
} break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
{
|
||||
const int64_t n_seq_tokens = 512;
|
||||
const int64_t n_seqs = 3;
|
||||
ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs);
|
||||
op_tensor = ggml_ssm_conv(ctx, conv_x, w);
|
||||
} break;
|
||||
case GGML_OP_SSM_SCAN:
|
||||
{
|
||||
// w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2
|
||||
const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0];
|
||||
const int64_t n_head = w->ne[1];
|
||||
const int64_t head_dim = hparams.ssm_d_inner / n_head;
|
||||
const int64_t n_group = hparams.ssm_n_group ? hparams.ssm_n_group : 1;
|
||||
const int64_t n_seq_tokens = 512;
|
||||
const int64_t n_seqs = 3;
|
||||
ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs);
|
||||
ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
|
||||
op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids);
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
{
|
||||
// FIXME
|
||||
const int64_t S = 123;
|
||||
const int64_t H = 123;
|
||||
const int64_t n_tokens = 123;
|
||||
const int64_t n_seqs = 123;
|
||||
ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
|
||||
ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
|
||||
ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
|
||||
ggml_tensor * tf = w;
|
||||
ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
|
||||
ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
|
||||
op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
|
||||
} break;
|
||||
case GGML_OP_IM2COL:
|
||||
{
|
||||
const int n_embd_inp = hparams.n_embd_inp();
|
||||
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_inp, w->ne[1], 1, 1);
|
||||
op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16);
|
||||
} break;
|
||||
case GGML_OP_SCALE:
|
||||
{
|
||||
op_tensor = ggml_scale(ctx, w, 1.0f);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
|
||||
}
|
||||
|
||||
// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
|
||||
GGML_ASSERT(w->buffer == nullptr);
|
||||
w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
|
||||
bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
|
||||
ggml_backend_buffer_free(w->buffer);
|
||||
w->buffer = nullptr;
|
||||
|
||||
return op_supported;
|
||||
}
|
||||
|
||||
// find the first buffer type in the list that can use the tensor
|
||||
static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hparams, ggml_tensor * tensor, ggml_op op, const buft_list_t * buft_list) {
|
||||
GGML_ASSERT(!buft_list->empty());
|
||||
for (const auto & cur : *buft_list) {
|
||||
ggml_backend_dev_t cur_dev = cur.first;
|
||||
ggml_backend_buffer_type_t cur_buft = cur.second;
|
||||
if (weight_buft_supported(hparams, tensor, op, cur_buft, cur_dev)) {
|
||||
return cur_buft;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
struct ggml_tensor * llama_model_loader::create_tensor(
|
||||
const llama_hparams & hparams, const buft_list_t * buft_list_cpu, const buft_list_t * buft_list_input, const buft_list_t * buft_list_output,
|
||||
const buft_list_t * buft_list_layer, const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags) {
|
||||
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
||||
auto it = ctx_map.find(buft);
|
||||
if (it == ctx_map.end()) {
|
||||
// one ggml context per buffer type
|
||||
int max_n_tensors = n_tensors;
|
||||
max_n_tensors += 1; // duplicated output tensor
|
||||
max_n_tensors += hparams.n_layer*2; // duplicated rope freq tensors
|
||||
if (files.empty()) {
|
||||
max_n_tensors += hparams.n_layer*256; // this should be well above what any model actually uses
|
||||
}
|
||||
const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors;
|
||||
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ ctx_size,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
ggml_context * ctx = ggml_init(params);
|
||||
if (!ctx) {
|
||||
throw std::runtime_error(format("failed to create ggml context"));
|
||||
}
|
||||
|
||||
ctx_map.emplace(buft, ctx);
|
||||
|
||||
return ctx;
|
||||
}
|
||||
return it->second.get();
|
||||
};
|
||||
|
||||
auto buft_for_tensor = [&](ggml_tensor * t_meta) -> ggml_backend_buffer_type_t {
|
||||
if (!t_meta) {
|
||||
if (flags & TENSOR_NOT_REQUIRED) {
|
||||
return nullptr;
|
||||
}
|
||||
throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str()));
|
||||
}
|
||||
|
||||
// some models use the token embedding tensor as the output, but since these are used in different layers and with different ops
|
||||
// the tensor is duplicated
|
||||
// to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor
|
||||
llm_tensor tn_tensor = tn.tensor;
|
||||
if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && (flags & TENSOR_DUPLICATED)) {
|
||||
tn_tensor = LLM_TENSOR_OUTPUT;
|
||||
}
|
||||
|
||||
llm_tensor_info info;
|
||||
try {
|
||||
info = llm_tensor_info_for(tn_tensor);
|
||||
} catch (const std::out_of_range & e) {
|
||||
throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str()));
|
||||
}
|
||||
|
||||
// skip unused tensors
|
||||
if (info.op == GGML_OP_NONE || (flags & TENSOR_SKIP)) {
|
||||
const size_t nbytes = ggml_nbytes(t_meta);
|
||||
LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes);
|
||||
|
||||
size_data -= nbytes;
|
||||
n_created++;
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// tensors with "bias" suffix are always used with GGML_OP_ADD or GGML_OP_ADD_ID
|
||||
ggml_op op;
|
||||
bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;
|
||||
if (bias) {
|
||||
if (info.op == GGML_OP_MUL_MAT_ID) {
|
||||
op = GGML_OP_ADD_ID;
|
||||
} else {
|
||||
op = GGML_OP_ADD;
|
||||
}
|
||||
} else {
|
||||
op = info.op;
|
||||
}
|
||||
|
||||
// sanity checks
|
||||
if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) {
|
||||
if (tn.bid != -1) {
|
||||
GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str());
|
||||
}
|
||||
} else {
|
||||
if (tn.bid == -1) {
|
||||
GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// select the buffer type for this tensor
|
||||
const buft_list_t * buft_list;
|
||||
switch (info.layer) {
|
||||
case LLM_TENSOR_LAYER_INPUT:
|
||||
buft_list = buft_list_input;
|
||||
break;
|
||||
case LLM_TENSOR_LAYER_OUTPUT:
|
||||
buft_list = buft_list_output;
|
||||
break;
|
||||
case LLM_TENSOR_LAYER_REPEATING:
|
||||
GGML_ASSERT(buft_list_layer != nullptr);
|
||||
buft_list = buft_list_layer;
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t buft = nullptr;
|
||||
|
||||
// check overrides
|
||||
if (tensor_buft_overrides) {
|
||||
std::string tensor_name = tn.str();
|
||||
for (const auto * overrides = tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
|
||||
std::regex pattern(overrides->pattern);
|
||||
if (std::regex_search(tensor_name, pattern)) {
|
||||
if (overrides->buft == ggml_backend_cpu_buffer_type()) {
|
||||
// when overriding to a CPU buffer, consider the extra buffer types
|
||||
buft = select_weight_buft(hparams, t_meta, op, buft_list_cpu);
|
||||
} else {
|
||||
buft = overrides->buft;
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n",
|
||||
tensor_name.c_str(),
|
||||
ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type),
|
||||
ggml_backend_buft_name(buft));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!buft) {
|
||||
buft = select_weight_buft(hparams, t_meta, op, buft_list);
|
||||
if (!buft) {
|
||||
throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
|
||||
}
|
||||
}
|
||||
|
||||
// avoid using a host buffer when using mmap
|
||||
auto * buft_dev = ggml_backend_buft_get_device(buft);
|
||||
if (use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
throw std::runtime_error("no CPU backend found");
|
||||
}
|
||||
buft = ggml_backend_dev_buffer_type(cpu_dev);
|
||||
}
|
||||
|
||||
if (buft != buft_list->front().second) {
|
||||
if (n_tensors_moved == 0) {
|
||||
first_tensor_moved_name = t_meta->name;
|
||||
first_tensor_moved_type_name = ggml_type_name(t_meta->type);
|
||||
first_moved_from_buft = buft_list->front().second;
|
||||
first_moved_to_buft = buft;
|
||||
}
|
||||
n_tensors_moved++;
|
||||
}
|
||||
|
||||
return buft;
|
||||
};
|
||||
|
||||
if (files.empty()) {
|
||||
if (flags & TENSOR_SKIP_IF_VIRTUAL) {
|
||||
return nullptr;
|
||||
}
|
||||
ggml_type type = GGML_TYPE_F32;
|
||||
const int64_t tid = gguf_find_tensor(metadata, tn.str().c_str());
|
||||
if (tid != -1) {
|
||||
type = gguf_get_tensor_type(metadata, tid);
|
||||
}
|
||||
|
||||
// for tensors that are not required some of the dimensions can be invalid:
|
||||
if (flags & TENSOR_NOT_REQUIRED) {
|
||||
for (size_t dim = 0; dim < ne.size(); dim++) {
|
||||
if (ne.begin()[dim] <= 0) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor t_meta;
|
||||
memset(&t_meta, 0, sizeof(ggml_tensor));
|
||||
t_meta.type = type;
|
||||
for (size_t dim = 0; dim < GGML_MAX_DIMS; dim++) {
|
||||
t_meta.ne[dim] = dim < ne.size() ? ne.begin()[dim] : 1;
|
||||
GGML_ASSERT(t_meta.ne[dim] >= 1);
|
||||
t_meta.nb[dim] = dim == 0 ? ggml_type_size(type) : t_meta.ne[dim-1]*t_meta.nb[dim-1];
|
||||
GGML_ASSERT(t_meta.nb[dim] >= 1);
|
||||
}
|
||||
ggml_set_name(&t_meta, tn.str().c_str());
|
||||
|
||||
ggml_backend_buffer_type_t buft = buft_for_tensor(&t_meta);
|
||||
GGML_ASSERT(buft != nullptr);
|
||||
ggml_context * ctx = ctx_for_buft(buft);
|
||||
ggml_tensor * ret = ggml_dup_tensor(ctx, &t_meta);
|
||||
ggml_set_name(ret, tn.str().c_str());
|
||||
return ret;
|
||||
}
|
||||
|
||||
ggml_tensor * t_meta = get_tensor_meta(tn.str().c_str());
|
||||
ggml_backend_buffer_type_t buft = buft_for_tensor(t_meta);
|
||||
if (buft == nullptr) {
|
||||
return nullptr; // return type is ggml_tensor *
|
||||
}
|
||||
ggml_context * ctx = ctx_for_buft(buft);
|
||||
|
||||
// if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one
|
||||
if (flags & TENSOR_DUPLICATED) {
|
||||
ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str());
|
||||
if (t) {
|
||||
return t;
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, tn.str().c_str());
|
||||
const struct ggml_tensor * cur = check_tensor_dims(tn.str(), ne, !(flags & TENSOR_NOT_REQUIRED));
|
||||
|
||||
if (cur == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
bool duplicated = flags & TENSOR_DUPLICATED;
|
||||
const bool duplicated = flags & TENSOR_DUPLICATED;
|
||||
|
||||
struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur);
|
||||
ggml_set_name(tensor, ggml_get_name(cur));
|
||||
@@ -858,7 +1240,6 @@ struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx
|
||||
}
|
||||
|
||||
return tensor;
|
||||
|
||||
}
|
||||
|
||||
struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required) {
|
||||
@@ -893,6 +1274,11 @@ void llama_model_loader::done_getting_tensors() const {
|
||||
if (n_created != n_tensors) {
|
||||
throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
|
||||
}
|
||||
if (n_tensors_moved > 0) {
|
||||
LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n",
|
||||
__func__, first_tensor_moved_name.c_str(), first_tensor_moved_type_name.c_str(), n_tensors_moved - 1,
|
||||
ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft));
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps) {
|
||||
@@ -974,6 +1360,12 @@ bool llama_model_loader::load_all_data(
|
||||
llama_mlocks * lmlocks,
|
||||
llama_progress_callback progress_callback,
|
||||
void * progress_callback_user_data) {
|
||||
if (files.empty()) {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
|
||||
set_tensor_data(t, set_tensor_data_ud);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
GGML_ASSERT(size_data != 0 && "call init_mappings() first");
|
||||
|
||||
std::vector<no_init<uint8_t>> read_buf;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user