mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-16 16:27:32 +03:00
Compare commits
2 Commits
b5391
...
gg/server-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
237acc7cd5 | ||
|
|
6190e1c1c9 |
@@ -73,8 +73,6 @@ add_library(${TARGET} STATIC
|
||||
minja/minja.hpp
|
||||
ngram-cache.cpp
|
||||
ngram-cache.h
|
||||
regex-partial.cpp
|
||||
regex-partial.h
|
||||
sampling.cpp
|
||||
sampling.h
|
||||
speculative.cpp
|
||||
|
||||
246
common/chat.cpp
246
common/chat.cpp
@@ -6,15 +6,6 @@
|
||||
|
||||
#include <optional>
|
||||
|
||||
static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
|
||||
auto time = std::chrono::system_clock::to_time_t(now);
|
||||
auto local_time = *std::localtime(&time);
|
||||
std::ostringstream ss;
|
||||
ss << std::put_time(&local_time, format.c_str());
|
||||
auto res = ss.str();
|
||||
return res;
|
||||
}
|
||||
|
||||
typedef minja::chat_template common_chat_template;
|
||||
|
||||
struct common_chat_templates {
|
||||
@@ -33,7 +24,6 @@ struct templates_params {
|
||||
std::string grammar;
|
||||
bool add_generation_prompt = true;
|
||||
bool extract_reasoning = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
};
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
|
||||
@@ -949,83 +939,78 @@ static void expect_tool_parameters(const std::string & name, const json & parame
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
|
||||
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
|
||||
auto builtin_tools = json::array();
|
||||
common_chat_params data;
|
||||
if (!inputs.tools.is_null()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
||||
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
||||
expect_tool_parameters(name, parameters, {"query"});
|
||||
} else if (name == "python" || name == "code_interpreter") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
|
||||
expect_tool_parameters(name, parameters, {"code"});
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> kvs;
|
||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
|
||||
}
|
||||
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
|
||||
builtin_tools.push_back(name);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
|
||||
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
||||
if (allow_python_tag_builtin_tools) {
|
||||
handle_builtin_tool(name, parameters);
|
||||
}
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"{\" space "
|
||||
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
|
||||
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
|
||||
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
|
||||
"\"}\" space"));
|
||||
});
|
||||
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
|
||||
});
|
||||
if (!builtin_tools.empty()) {
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
||||
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
||||
expect_tool_parameters(name, parameters, {"query"});
|
||||
} else if (name == "python" || name == "code_interpreter") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
|
||||
expect_tool_parameters(name, parameters, {"code"});
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
// Allow a few empty lines on top of the usual constrained json schema space rule.
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
|
||||
std::vector<std::string> kvs;
|
||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
|
||||
}
|
||||
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
|
||||
builtin_tools.push_back(name);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
|
||||
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
||||
if (allow_python_tag_builtin_tools) {
|
||||
handle_builtin_tool(name, parameters);
|
||||
}
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"{\" space "
|
||||
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
|
||||
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
|
||||
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
|
||||
"\"}\" space"));
|
||||
});
|
||||
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
|
||||
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
|
||||
: COMMON_CHAT_FORMAT_LLAMA_3_X;
|
||||
} else {
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
}
|
||||
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
|
||||
});
|
||||
if (!builtin_tools.empty()) {
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
// Allow a few empty lines on top of the usual constrained json schema space rule.
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
});
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
||||
{"date_string", format_time(inputs.now, "%d %b %Y")},
|
||||
{"tools_in_user_message", false},
|
||||
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
||||
});
|
||||
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
|
||||
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
|
||||
: COMMON_CHAT_FORMAT_LLAMA_3_X;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
|
||||
@@ -1165,7 +1150,7 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
|
||||
LOG_DBG("%s\n", __func__);
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
||||
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
|
||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
});
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
@@ -1300,59 +1285,55 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
|
||||
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
common_chat_params data;
|
||||
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
|
||||
std::string python_code_argument_name;
|
||||
auto has_raw_python = false;
|
||||
|
||||
if (!inputs.tools.is_null()) {
|
||||
std::string python_code_argument_name;
|
||||
auto has_raw_python = false;
|
||||
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
const auto & parameters = function.at("parameters");
|
||||
std::string name = function.at("name");
|
||||
if (name == "python" || name == "ipython") {
|
||||
if (!parameters.contains("type")) {
|
||||
throw std::runtime_error("Missing type in python tool");
|
||||
}
|
||||
has_raw_python = true;
|
||||
const auto & type = parameters.at("type");
|
||||
if (type == "object") {
|
||||
auto properties = parameters.at("properties");
|
||||
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
||||
if (it.value().at("type") == "string") {
|
||||
if (!python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("Multiple string arguments found in python tool");
|
||||
}
|
||||
python_code_argument_name = it.key();
|
||||
}
|
||||
}
|
||||
if (python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("No string argument found in python tool");
|
||||
}
|
||||
} else if (type != "string") {
|
||||
throw std::runtime_error("Invalid type in python tool: " + type.dump());
|
||||
}
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
const auto & parameters = function.at("parameters");
|
||||
std::string name = function.at("name");
|
||||
if (name == "python" || name == "ipython") {
|
||||
if (!parameters.contains("type")) {
|
||||
throw std::runtime_error("Missing type in python tool");
|
||||
}
|
||||
has_raw_python = true;
|
||||
const auto & type = parameters.at("type");
|
||||
if (type == "object") {
|
||||
auto properties = parameters.at("properties");
|
||||
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
||||
if (it.value().at("type") == "string") {
|
||||
if (!python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("Multiple string arguments found in python tool");
|
||||
}
|
||||
python_code_argument_name = it.key();
|
||||
}
|
||||
}
|
||||
if (python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("No string argument found in python tool");
|
||||
}
|
||||
} else if (type != "string") {
|
||||
throw std::runtime_error("Invalid type in python tool: " + type.dump());
|
||||
}
|
||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||
});
|
||||
if (has_raw_python) {
|
||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
|
||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||
});
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
|
||||
} else {
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
}
|
||||
if (has_raw_python) {
|
||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
|
||||
});
|
||||
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
// TODO: if (has_raw_python)
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
|
||||
@@ -1612,7 +1593,6 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
params.extract_reasoning = inputs.extract_reasoning;
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.grammar = inputs.grammar;
|
||||
params.now = inputs.now;
|
||||
if (!inputs.json_schema.empty()) {
|
||||
params.json_schema = json::parse(inputs.json_schema);
|
||||
}
|
||||
@@ -1664,21 +1644,21 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_firefunction_v2(tmpl, params);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
}
|
||||
|
||||
// Functionary v3.1 (w/ tools)
|
||||
if (src.find("<|start_header_id|>") != std::string::npos
|
||||
&& src.find("<function=") != std::string::npos) {
|
||||
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
|
||||
}
|
||||
|
||||
// Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
|
||||
// Llama 3.1, 3.2, 3.3 (w/ tools)
|
||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
|
||||
}
|
||||
|
||||
// Mistral Nemo (w/ tools)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@@ -72,7 +71,6 @@ struct common_chat_templates_inputs {
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
bool parallel_tool_calls = false;
|
||||
bool extract_reasoning = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
};
|
||||
|
||||
struct common_chat_params {
|
||||
|
||||
@@ -443,25 +443,6 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
||||
s = std::move(builder);
|
||||
}
|
||||
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
|
||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
|
||||
if (!str.empty() && !stop.empty()) {
|
||||
const char text_last_char = str.back();
|
||||
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||
if (stop[char_index] == text_last_char) {
|
||||
const auto current_partial = stop.substr(0, char_index + 1);
|
||||
if (string_ends_with(str, current_partial)) {
|
||||
return str.size() - char_index - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
std::string regex_escape(const std::string & s) {
|
||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||
return std::regex_replace(s, special_chars, "\\$0");
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
@@ -504,9 +503,10 @@ static bool string_starts_with(const std::string & str,
|
||||
return str.rfind(prefix, 0) == 0;
|
||||
}
|
||||
|
||||
// While we wait for C++20's std::string::ends_with...
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
|
||||
static bool string_ends_with(const std::string & str,
|
||||
const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
|
||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
|
||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
||||
void string_process_escapes(std::string & input);
|
||||
|
||||
@@ -1,204 +0,0 @@
|
||||
#include "regex-partial.h"
|
||||
#include "common.h"
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
|
||||
common_regex::common_regex(const std::string & pattern) :
|
||||
pattern(pattern),
|
||||
rx(pattern),
|
||||
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
|
||||
|
||||
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
|
||||
std::smatch match;
|
||||
if (pos > input.size()) {
|
||||
throw std::runtime_error("Position out of bounds");
|
||||
}
|
||||
auto start = input.begin() + pos;
|
||||
auto found = as_match
|
||||
? std::regex_match(start, input.end(), match, rx)
|
||||
: std::regex_search(start, input.end(), match, rx);
|
||||
if (found) {
|
||||
common_regex_match res;
|
||||
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
|
||||
for (size_t i = 0; i < match.size(); ++i) {
|
||||
auto begin = pos + match.position(i);
|
||||
res.groups.emplace_back(begin, begin + match.length(i));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
std::match_results<std::string::const_reverse_iterator> srmatch;
|
||||
if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
|
||||
auto group = srmatch[1].str();
|
||||
if (group.length() != 0) {
|
||||
auto it = srmatch[1].second.base();
|
||||
// auto position = static_cast<size_t>(std::distance(input.begin(), it));
|
||||
if ((!as_match) || it == input.begin()) {
|
||||
common_regex_match res;
|
||||
res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
|
||||
const size_t begin = std::distance(input.begin(), it);
|
||||
const size_t end = input.size();
|
||||
if (begin == std::string::npos || end == std::string::npos || begin > end) {
|
||||
throw std::runtime_error("Invalid range");
|
||||
}
|
||||
res.groups.push_back({begin, end});
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
/*
|
||||
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
|
||||
|
||||
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
|
||||
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
|
||||
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
|
||||
|
||||
- /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).*
|
||||
- /a|b/ -> (a|b).*
|
||||
- /a*?/ -> error, could match ""
|
||||
- /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
|
||||
- /.*?ab/ -> ((?:b)?a).* (merge .*)
|
||||
- /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
|
||||
- /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
|
||||
- /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
|
||||
- /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
|
||||
|
||||
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern
|
||||
(i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
|
||||
*/
|
||||
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
|
||||
auto it = pattern.begin();
|
||||
const auto end = pattern.end();
|
||||
|
||||
std::function<std::string()> process = [&]() {
|
||||
std::vector<std::vector<std::string>> alternatives(1);
|
||||
std::vector<std::string> * sequence = &alternatives.back();
|
||||
|
||||
while (it != end) {
|
||||
if (*it == '[') {
|
||||
auto start = it;
|
||||
++it;
|
||||
while (it != end) {
|
||||
if ((*it == '\\') && (++it != end)) {
|
||||
++it;
|
||||
} else if ((it != end) && (*it == ']')) {
|
||||
break;
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
if (it == end) {
|
||||
throw std::runtime_error("Unmatched '[' in pattern");
|
||||
}
|
||||
++it;
|
||||
sequence->push_back(std::string(start, it));
|
||||
} else if (*it == '*' || *it == '?' || *it == '+') {
|
||||
if (sequence->empty()) {
|
||||
throw std::runtime_error("Quantifier without preceding element");
|
||||
}
|
||||
sequence->back() += *it;
|
||||
auto is_star = *it == '*';
|
||||
++it;
|
||||
if (is_star) {
|
||||
if (*it == '?') {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
} else if (*it == '{') {
|
||||
if (sequence->empty()) {
|
||||
throw std::runtime_error("Repetition without preceding element");
|
||||
}
|
||||
++it;
|
||||
auto start = it;
|
||||
while (it != end && *it != '}') {
|
||||
++it;
|
||||
}
|
||||
if (it == end) {
|
||||
throw std::runtime_error("Unmatched '{' in pattern");
|
||||
}
|
||||
auto parts = string_split(std::string(start, it), ",");
|
||||
++it;
|
||||
if (parts.size() > 2) {
|
||||
throw std::runtime_error("Invalid repetition range in pattern");
|
||||
}
|
||||
|
||||
auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
|
||||
if (s.empty()) {
|
||||
return def;
|
||||
}
|
||||
return std::stoi(s);
|
||||
};
|
||||
auto min = parseOptInt(parts[0], 0);
|
||||
auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
|
||||
if (min && max && *max < *min) {
|
||||
throw std::runtime_error("Invalid repetition range in pattern");
|
||||
}
|
||||
// Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
|
||||
auto part = sequence->back();
|
||||
sequence->pop_back();
|
||||
for (int i = 0; i < *min; i++) {
|
||||
sequence->push_back(part);
|
||||
}
|
||||
if (max) {
|
||||
for (int i = *min; i < *max; i++) {
|
||||
sequence->push_back(part + "?");
|
||||
}
|
||||
} else {
|
||||
sequence->push_back(part + "*");
|
||||
}
|
||||
} else if (*it == '(') {
|
||||
++it;
|
||||
if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
|
||||
it += 2;
|
||||
}
|
||||
auto sub = process();
|
||||
if (*it != ')') {
|
||||
throw std::runtime_error("Unmatched '(' in pattern");
|
||||
}
|
||||
++it;
|
||||
auto & part = sequence->emplace_back("(?:");
|
||||
part += sub;
|
||||
part += ")";
|
||||
} else if (*it == ')') {
|
||||
break;
|
||||
} else if (*it == '|') {
|
||||
++it;
|
||||
alternatives.emplace_back();
|
||||
sequence = &alternatives.back();
|
||||
} else if (*it == '\\' && (++it != end)) {
|
||||
auto str = std::string("\\") + *it;
|
||||
sequence->push_back(str);
|
||||
++it;
|
||||
} else if (it != end) {
|
||||
sequence->push_back(std::string(1, *it));
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
// /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
|
||||
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
|
||||
// We'll do the outermost capturing group and final .* in the enclosing function.
|
||||
std::vector<std::string> res_alts;
|
||||
for (const auto & parts : alternatives) {
|
||||
auto & res = res_alts.emplace_back();
|
||||
for (size_t i = 0; i < parts.size() - 1; i++) {
|
||||
res += "(?:";
|
||||
}
|
||||
for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
|
||||
res += *it;
|
||||
if (it != parts.rend() - 1) {
|
||||
res += ")?";
|
||||
}
|
||||
}
|
||||
}
|
||||
return string_join(res_alts, "|");
|
||||
};
|
||||
auto res = process();
|
||||
if (it != end) {
|
||||
throw std::runtime_error("Unmatched '(' in pattern");
|
||||
}
|
||||
|
||||
return "(" + res + ")[\\s\\S]*";
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <regex>
|
||||
#include <string>
|
||||
|
||||
enum common_regex_match_type {
|
||||
COMMON_REGEX_MATCH_TYPE_NONE,
|
||||
COMMON_REGEX_MATCH_TYPE_PARTIAL,
|
||||
COMMON_REGEX_MATCH_TYPE_FULL,
|
||||
};
|
||||
|
||||
struct common_string_range {
|
||||
size_t begin;
|
||||
size_t end;
|
||||
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
|
||||
if (begin > end) {
|
||||
throw std::runtime_error("Invalid range");
|
||||
}
|
||||
}
|
||||
// prevent default ctor
|
||||
common_string_range() = delete;
|
||||
bool empty() const {
|
||||
return begin == end;
|
||||
}
|
||||
bool operator==(const common_string_range & other) const {
|
||||
return begin == other.begin && end == other.end;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_regex_match {
|
||||
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
|
||||
std::vector<common_string_range> groups;
|
||||
|
||||
bool operator==(const common_regex_match & other) const {
|
||||
return type == other.type && groups == other.groups;
|
||||
}
|
||||
bool operator!=(const common_regex_match & other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
|
||||
class common_regex {
|
||||
std::string pattern;
|
||||
std::regex rx;
|
||||
std::regex rx_reversed_partial;
|
||||
|
||||
public:
|
||||
explicit common_regex(const std::string & pattern);
|
||||
|
||||
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
|
||||
|
||||
const std::string & str() const { return pattern; }
|
||||
};
|
||||
|
||||
// For testing only (pretty print of failures).
|
||||
std::string regex_to_reversed_partial_regex(const std::string & pattern);
|
||||
@@ -8519,11 +8519,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
|
||||
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
#ifdef __ARM_FEATURE_MATMUL_INT8
|
||||
assert((nrc == 2) || (nrc == 1));
|
||||
#else
|
||||
assert(nrc == 1);
|
||||
#endif
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
@@ -8534,197 +8530,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (nrc == 2) {
|
||||
const block_q6_K * GGML_RESTRICT x0 = x;
|
||||
const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
|
||||
const block_q8_K * GGML_RESTRICT y0 = y;
|
||||
const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
|
||||
|
||||
float32x4_t vfsum = vdupq_n_f32(0.0f);
|
||||
|
||||
for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
|
||||
const uint8_t * GGML_RESTRICT ql0 = x0->ql;
|
||||
const uint8_t * GGML_RESTRICT ql1 = x1->ql;
|
||||
const uint8_t * GGML_RESTRICT qh0 = x0->qh;
|
||||
const uint8_t * GGML_RESTRICT qh1 = x1->qh;
|
||||
const int8_t * GGML_RESTRICT qy0 = y0->qs;
|
||||
const int8_t * GGML_RESTRICT qy1 = y1->qs;
|
||||
|
||||
const uint8x16_t mone = vdupq_n_u8(0x30);
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
|
||||
int32x4_t visum = vdupq_n_s32(0);
|
||||
|
||||
// process 8 blocks per iteration, totally 16 blocks
|
||||
for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
|
||||
int8x16_t vx0[8], vx1[8];
|
||||
|
||||
// de-quantize vx0[8]
|
||||
{
|
||||
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
|
||||
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
|
||||
|
||||
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
|
||||
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
|
||||
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
|
||||
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
|
||||
|
||||
vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
|
||||
vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
|
||||
vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
|
||||
vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
|
||||
|
||||
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
|
||||
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
|
||||
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
|
||||
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
|
||||
|
||||
vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
|
||||
vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
|
||||
vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
|
||||
vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
|
||||
}
|
||||
|
||||
// de-quantize vx1[8]
|
||||
{
|
||||
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
|
||||
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
|
||||
|
||||
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
|
||||
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
|
||||
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
|
||||
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
|
||||
|
||||
vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
|
||||
vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
|
||||
vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
|
||||
vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
|
||||
|
||||
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
|
||||
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
|
||||
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
|
||||
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
|
||||
|
||||
vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
|
||||
vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
|
||||
vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
|
||||
vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
|
||||
}
|
||||
|
||||
// process 16 elements (one block with same scale) per iteration
|
||||
// - vx = concat(ql, qh) - 32
|
||||
// - r1,r2,r3,r4 = smmla(vx, vy)
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
const int blk = j * 8 + k;
|
||||
|
||||
const int8x16_t vy0 = vld1q_s8(qy0);
|
||||
const int8x16_t vy1 = vld1q_s8(qy1);
|
||||
qy0 += 16;
|
||||
qy1 += 16;
|
||||
|
||||
const int32x4_t block_scale = {
|
||||
x0->scales[blk],
|
||||
x0->scales[blk],
|
||||
x1->scales[blk],
|
||||
x1->scales[blk],
|
||||
};
|
||||
|
||||
// calculate four results at once with outer product
|
||||
const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
|
||||
const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
|
||||
const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
|
||||
const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
|
||||
int32x4_t vr = vdupq_n_s32(0);
|
||||
vr = vmmlaq_s32(vr, vx_l, vy_l);
|
||||
vr = vmmlaq_s32(vr, vx_h, vy_h);
|
||||
|
||||
// apply block scale, will NOT overflow
|
||||
// block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
|
||||
visum = vmlaq_s32(visum, vr, block_scale);
|
||||
}
|
||||
}
|
||||
|
||||
// adjust bias, apply superblock scale
|
||||
{
|
||||
int32_t bias[4];
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
|
||||
const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
|
||||
const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
|
||||
const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
|
||||
const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
|
||||
const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
|
||||
const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
|
||||
const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
|
||||
const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
|
||||
const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
|
||||
const svint64_t zero = svdup_n_s64(0);
|
||||
bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
|
||||
svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
|
||||
bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
|
||||
svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
|
||||
bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
|
||||
svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
|
||||
bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
|
||||
svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
|
||||
#else
|
||||
// NEON doesn't support int16 dot product, fallback to separated mul and add
|
||||
const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
|
||||
const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
|
||||
|
||||
int8x16_t scales_s8 = vld1q_s8(x0->scales);
|
||||
const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
|
||||
scales_s8 = vld1q_s8(x1->scales);
|
||||
const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
|
||||
|
||||
int32x4_t prod;
|
||||
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
|
||||
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
|
||||
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
|
||||
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
|
||||
bias[0] = vaddvq_s32(prod);
|
||||
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
|
||||
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
|
||||
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
|
||||
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
|
||||
bias[1] = vaddvq_s32(prod);
|
||||
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
|
||||
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
|
||||
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
|
||||
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
|
||||
bias[2] = vaddvq_s32(prod);
|
||||
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
|
||||
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
|
||||
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
|
||||
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
|
||||
bias[3] = vaddvq_s32(prod);
|
||||
|
||||
#endif
|
||||
const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
|
||||
|
||||
const float32x4_t superblock_scale = {
|
||||
GGML_FP16_TO_FP32(x0->d) * y0->d,
|
||||
GGML_FP16_TO_FP32(x0->d) * y1->d,
|
||||
GGML_FP16_TO_FP32(x1->d) * y0->d,
|
||||
GGML_FP16_TO_FP32(x1->d) * y1->d,
|
||||
};
|
||||
|
||||
visum = vsubq_s32(visum, vibias);
|
||||
vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
|
||||
}
|
||||
}
|
||||
|
||||
// vfsum = ABCD -> ACBD
|
||||
// AC -> s, BD -> (s+bs)
|
||||
vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
|
||||
vst1_f32(s, vget_low_f32 (vfsum));
|
||||
vst1_f32(s + bs, vget_high_f32(vfsum));
|
||||
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
||||
float sum = 0;
|
||||
|
||||
@@ -282,11 +282,7 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
||||
.from_float = quantize_row_q6_K,
|
||||
.vec_dot = ggml_vec_dot_q6_K_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
||||
.nrows = 2,
|
||||
#else
|
||||
.nrows = 1,
|
||||
#endif
|
||||
},
|
||||
[GGML_TYPE_IQ2_XXS] = {
|
||||
.from_float = NULL,
|
||||
|
||||
@@ -678,14 +678,10 @@ void launch_fattn(
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
const bool is_mla = DV == 512; // TODO better parameterization
|
||||
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(V || is_mla);
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
@@ -693,10 +689,6 @@ void launch_fattn(
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
|
||||
GGML_ASSERT( K->nb[0] == ggml_element_size(K));
|
||||
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
|
||||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
@@ -721,10 +713,10 @@ void launch_fattn(
|
||||
size_t nb12 = K->nb[2];
|
||||
size_t nb13 = K->nb[3];
|
||||
|
||||
const char * V_data = V ? (const char *) V->data : nullptr;
|
||||
size_t nb21 = V ? V->nb[1] : nb11;
|
||||
size_t nb22 = V ? V->nb[2] : nb12;
|
||||
size_t nb23 = V ? V->nb[3] : nb13;
|
||||
const char * V_data = (const char *) V->data;
|
||||
size_t nb21 = V->nb[1];
|
||||
size_t nb22 = V->nb[2];
|
||||
size_t nb23 = V->nb[3];
|
||||
|
||||
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(K));
|
||||
@@ -741,7 +733,7 @@ void launch_fattn(
|
||||
nb13 = nb13*bs*sizeof(half)/ts;
|
||||
}
|
||||
|
||||
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
if (need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(V));
|
||||
V_f16.alloc(ggml_nelements(V));
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||
|
||||
@@ -33,30 +33,9 @@ struct fattn_mma_f16_config< 64, 64> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
static constexpr int nbatch_K2 = 32;
|
||||
static constexpr int nbatch_V2 = 32;
|
||||
static constexpr int nbatch_combine = 32;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -65,30 +44,9 @@ struct fattn_mma_f16_config< 80, 80> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
static constexpr int nbatch_K2 = 40;
|
||||
static constexpr int nbatch_V2 = 40;
|
||||
static constexpr int nbatch_combine = 40;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -97,30 +55,9 @@ struct fattn_mma_f16_config< 96, 96> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
static constexpr int nbatch_K2 = 48;
|
||||
static constexpr int nbatch_V2 = 48;
|
||||
static constexpr int nbatch_combine = 48;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -129,30 +66,9 @@ struct fattn_mma_f16_config<112, 112> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
static constexpr int nbatch_K2 = 56;
|
||||
static constexpr int nbatch_V2 = 56;
|
||||
static constexpr int nbatch_combine = 56;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -161,30 +77,9 @@ struct fattn_mma_f16_config<128, 128> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
static constexpr int nbatch_K2 = 64;
|
||||
static constexpr int nbatch_V2 = 64;
|
||||
static constexpr int nbatch_combine = 64;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -193,38 +88,9 @@ struct fattn_mma_f16_config<256, 256> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int cc, const int ncols) {
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
}
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int ncols) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
#else
|
||||
GGML_UNUSED(ncols);
|
||||
return 128;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
}
|
||||
static constexpr int nbatch_K2 = 128;
|
||||
static constexpr int nbatch_V2 = 128;
|
||||
static constexpr int nbatch_combine = 128;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -233,44 +99,9 @@ struct fattn_mma_f16_config<576, 512> {
|
||||
static constexpr int nwarps_max = 8;
|
||||
static constexpr bool Q_in_reg = false;
|
||||
static constexpr int nstages_target = 1;
|
||||
|
||||
static int get_nbatch_K2_host(const int cc, const int ncols) {
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
||||
return ncols <= 16 ? 96 : 160;
|
||||
}
|
||||
return ncols <= 16 ? 288 : 160;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int ncols) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
return ncols <= 16 ? 96 : 160;
|
||||
#else
|
||||
return ncols <= 16 ? 288 : 160;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int cc, const int ncols) {
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
||||
return ncols <= 16 ? 64 : 128;
|
||||
}
|
||||
return ncols <= 16 ? 256 : 128;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int ncols) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
return ncols <= 16 ? 64 : 128;
|
||||
#else
|
||||
return ncols <= 16 ? 256 : 128;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
static constexpr int nbatch_K2 = 160;
|
||||
static constexpr int nbatch_V2 = 128;
|
||||
static constexpr int nbatch_combine = 128;
|
||||
};
|
||||
|
||||
// ------------------------------------------------------------------------------------------------------------------
|
||||
@@ -289,7 +120,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
||||
|
||||
auto load = [&] __device__ (auto n) {
|
||||
auto load = [&] __device__ (const int n) {
|
||||
const int stride_k = WARP_SIZE >> n;
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
||||
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
||||
@@ -392,7 +223,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
}
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const float2 * const __restrict__ Q_f2,
|
||||
const half2 * const __restrict__ K_h2,
|
||||
@@ -430,15 +261,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
constexpr int cols_per_warp = ntiles * tile_B::I;
|
||||
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
||||
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
||||
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = nbatch_K2 + 4;
|
||||
|
||||
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
||||
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = c::nbatch_K2 + 4;
|
||||
constexpr int stride_tile_V = c::nbatch_V2 + 4;
|
||||
|
||||
const int k_VKQ_0 = kb0 * c::nbatch_fa;
|
||||
tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
|
||||
@@ -449,13 +275,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
|
||||
|
||||
if constexpr (nstages > 1) {
|
||||
static_assert(!mla, "multi-stage loading not implemented for MLA");
|
||||
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
||||
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
||||
constexpr bool use_cp_async = true;
|
||||
cp_async_wait_all();
|
||||
__syncthreads();
|
||||
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
||||
(V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
|
||||
} else {
|
||||
constexpr bool use_cp_async = nstages == 1;
|
||||
if (ncols2 > 1 || mask_h2) {
|
||||
@@ -464,8 +289,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
|
||||
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
|
||||
for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
|
||||
const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
|
||||
const int k0_diff = k0_stop - k0_start;
|
||||
|
||||
if (nstages <= 1) {
|
||||
@@ -712,21 +537,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
|
||||
}
|
||||
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
|
||||
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// For MLA K and V have the same data.
|
||||
// Therefore, iterate over V in reverse and re-use the data if possible.
|
||||
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
|
||||
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
|
||||
#pragma unroll
|
||||
for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
|
||||
const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
|
||||
const int i0_diff = i0_stop - i0_start;
|
||||
for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
|
||||
const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
|
||||
const int i0_diff = i0_stop - i0_start;
|
||||
|
||||
if (nstages <= 1 && i0_start < reusable_cutoff) {
|
||||
if (nstages <= 1) {
|
||||
constexpr bool use_cp_async = nstages == 1;
|
||||
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
||||
@@ -735,7 +555,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
|
||||
|
||||
// Calculate VKQ tile:
|
||||
#pragma unroll
|
||||
@@ -746,7 +565,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
|
||||
|
||||
tile_A A;
|
||||
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||
load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||
if (ntiles == 1) {
|
||||
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
|
||||
} else {
|
||||
@@ -777,7 +596,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const float2 * const __restrict__ Q_f2,
|
||||
const half2 * const __restrict__ K_h2,
|
||||
@@ -813,16 +632,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr int cols_per_warp = ntiles * tile_B::I;
|
||||
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
||||
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
||||
|
||||
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
|
||||
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = nbatch_K2 + 4;
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = c::nbatch_K2 + 4;
|
||||
constexpr int stride_tile_V = c::nbatch_V2 + 4;
|
||||
|
||||
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
||||
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
||||
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
||||
|
||||
extern __shared__ half2 tile_Q[];
|
||||
@@ -910,26 +726,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
// Preload mask and K data for first iteration when using cp_async with multiple stages:
|
||||
if constexpr (nstages > 1) {
|
||||
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
|
||||
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
|
||||
constexpr bool use_cp_async = true;
|
||||
if (ncols2 > 1 || mask_h2) {
|
||||
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
|
||||
}
|
||||
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
||||
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
|
||||
}
|
||||
|
||||
// Iterate over ne11 == previous tokens:
|
||||
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
||||
constexpr bool last_iter = false;
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
||||
}
|
||||
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
||||
constexpr bool last_iter = true;
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
||||
}
|
||||
@@ -958,7 +774,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
||||
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
|
||||
|
||||
constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
|
||||
constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
|
||||
constexpr int tile_stride = nbatch_combine + 4;
|
||||
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
|
||||
|
||||
@@ -1196,7 +1012,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
@@ -1241,14 +1057,6 @@ static __global__ void flash_attn_ext_f16(
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
if (ncols1*ncols2 > 32) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
|
||||
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
|
||||
|
||||
typedef fattn_mma_f16_config<DKQ, DV> c;
|
||||
|
||||
@@ -1259,10 +1067,9 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int stride_Q1 = nb01 / sizeof(float2);
|
||||
const int stride_Q2 = nb02 / sizeof(float2);
|
||||
const int stride_K = nb11 / sizeof(half2);
|
||||
const int stride_V = nb21 / sizeof(half2);
|
||||
const int stride_mask = nb31 / sizeof(half2);
|
||||
|
||||
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
|
||||
|
||||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
|
||||
@@ -1285,11 +1092,10 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
@@ -1298,12 +1104,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
||||
if (kb0_start == 0) {
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
}
|
||||
@@ -1324,11 +1130,10 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
@@ -1336,7 +1141,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
@@ -1362,6 +1167,10 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
|
||||
typedef fattn_mma_f16_config<DKQ, DV> c;
|
||||
|
||||
constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
|
||||
constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
|
||||
constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
|
||||
|
||||
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
|
||||
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
@@ -1371,21 +1180,15 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
|
||||
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
|
||||
|
||||
constexpr bool mla = DKQ == 576;
|
||||
|
||||
const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
|
||||
const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols);
|
||||
const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
|
||||
|
||||
static_assert(DKQ % tile_B::J == 0, "bad DKQ");
|
||||
static_assert(DV % tile_A::J == 0, "bad DV");
|
||||
static_assert(ncols % cols_per_warp == 0, "bad ncols");
|
||||
|
||||
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
|
||||
|
||||
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
|
||||
|
||||
@@ -1399,7 +1202,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
|
||||
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
@@ -1410,7 +1213,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
|
||||
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
|
||||
template <int DKQ, int DV, int ncols2>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
if constexpr (ncols2 <= 8) {
|
||||
@@ -25,7 +24,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
|
||||
return;
|
||||
}
|
||||
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
|
||||
if (Q->ne[1] <= 32/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -3222,7 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
||||
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
||||
if (!new_mma_available(cc)) {
|
||||
if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
|
||||
return false;
|
||||
}
|
||||
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
|
||||
|
||||
@@ -122,7 +122,6 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
|
||||
ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
|
||||
@@ -206,7 +205,6 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s13 = src1->nb[2] / ts_src1;
|
||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
|
||||
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
|
||||
|
||||
@@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1(
|
||||
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
|
||||
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
|
||||
|
||||
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
|
||||
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i1 = blockIdx.x;
|
||||
const int64_t i1 = blockIdx.y;
|
||||
const int64_t i2 = blockIdx.z % ne2;
|
||||
const int64_t i3 = blockIdx.z / ne2;
|
||||
|
||||
@@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1(
|
||||
|
||||
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
|
||||
|
||||
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
|
||||
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel
|
||||
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
|
||||
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel
|
||||
const int64_t iqs = i0 % (4*QK8_1); // quant index in block
|
||||
|
||||
// Load 4 floats per thread and calculate max. abs. value between them:
|
||||
@@ -166,9 +166,8 @@ void quantize_mmq_q8_1_cuda(
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
|
||||
|
||||
// ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
|
||||
const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
|
||||
const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
|
||||
const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
|
||||
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
|
||||
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
|
||||
switch (mmq_get_q8_1_ds_layout(type_src0)) {
|
||||
case MMQ_Q8_1_DS_LAYOUT_D4:
|
||||
|
||||
@@ -1704,12 +1704,10 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
if (kv_self != nullptr) {
|
||||
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
||||
kv_self->state_write(io);
|
||||
}
|
||||
kv_self->state_write(io);
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
@@ -441,13 +441,6 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
|
||||
|
||||
void llama_kv_cache_unified::set_full() {
|
||||
n = size;
|
||||
|
||||
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
|
||||
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
|
||||
// we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
|
||||
// setting it to 0 is the simplest way to achieve that
|
||||
// ref: https://github.com/ggml-org/llama.cpp/issues/13359
|
||||
head = 0;
|
||||
}
|
||||
|
||||
llama_sbatch llama_kv_cache_unified::sbatch_init(
|
||||
@@ -1719,7 +1712,6 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {
|
||||
|
||||
void llama_kv_cache_recurrent::set_full() {
|
||||
n = size;
|
||||
head = 0;
|
||||
}
|
||||
|
||||
llama_sbatch llama_kv_cache_recurrent::sbatch_init(
|
||||
|
||||
@@ -171,8 +171,11 @@ public:
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||
uint32_t size = 0; // total number of cells, shared across all sequences
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
||||
// computed before each graph build
|
||||
@@ -340,8 +343,11 @@ public:
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||
uint32_t size = 0; // total number of cells, shared across all sequences
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
||||
// computed before each graph build
|
||||
|
||||
@@ -822,18 +822,13 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
|
||||
mappings.reserve(files.size());
|
||||
mmaps_used.reserve(files.size());
|
||||
for (const auto & file : files) {
|
||||
bool is_numa = false;
|
||||
|
||||
auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (dev) {
|
||||
auto * reg = ggml_backend_dev_backend_reg(dev);
|
||||
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
|
||||
if (is_numa_fn) {
|
||||
is_numa = is_numa_fn();
|
||||
}
|
||||
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
|
||||
if (!reg) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
|
||||
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa);
|
||||
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
|
||||
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
|
||||
mmaps_used.emplace_back(mapping->size(), 0);
|
||||
if (mlock_mmaps) {
|
||||
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
|
||||
|
||||
@@ -12218,9 +12218,6 @@ struct llm_build_granite : public llm_graph_context {
|
||||
|
||||
// inp_pos - built only if rope enabled
|
||||
ggml_tensor * inp_pos = nullptr;
|
||||
if (use_rope) {
|
||||
inp_pos = build_inp_pos();
|
||||
}
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
@@ -12263,6 +12260,10 @@ struct llm_build_granite : public llm_graph_context {
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
if (use_rope) {
|
||||
|
||||
if (!inp_pos) {
|
||||
inp_pos = build_inp_pos();
|
||||
}
|
||||
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, rope_factors,
|
||||
|
||||
@@ -144,7 +144,6 @@ endif()
|
||||
|
||||
llama_build_and_test(test-log.cpp)
|
||||
llama_build_and_test(test-chat-template.cpp)
|
||||
llama_build_and_test(test-regex-partial.cpp)
|
||||
|
||||
# this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
|
||||
if (NOT WIN32)
|
||||
|
||||
@@ -832,9 +832,7 @@ static void test_template_output_parsers() {
|
||||
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
|
||||
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
||||
|
||||
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
|
||||
|
||||
@@ -1,288 +0,0 @@
|
||||
// Tests common_regex (esp. its partial final matches support).
|
||||
|
||||
#include "common.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
|
||||
template <class T> static void assert_equals(const T & expected, const T & actual) {
|
||||
if (expected != actual) {
|
||||
std::cerr << "Expected: " << expected << std::endl;
|
||||
std::cerr << " Actual: " << actual << std::endl;
|
||||
std::cerr << std::flush;
|
||||
throw std::runtime_error("Test failed");
|
||||
}
|
||||
}
|
||||
|
||||
struct test_case {
|
||||
std::string pattern;
|
||||
struct input_output {
|
||||
std::string input;
|
||||
common_regex_match output;
|
||||
};
|
||||
std::vector<input_output> inputs_outputs;
|
||||
};
|
||||
|
||||
static std::string common_regex_match_type_name(common_regex_match_type type) {
|
||||
switch (type) {
|
||||
case COMMON_REGEX_MATCH_TYPE_NONE:
|
||||
return "COMMON_REGEX_MATCH_TYPE_NONE";
|
||||
case COMMON_REGEX_MATCH_TYPE_PARTIAL:
|
||||
return "COMMON_REGEX_MATCH_TYPE_PARTIAL";
|
||||
case COMMON_REGEX_MATCH_TYPE_FULL:
|
||||
return "COMMON_REGEX_MATCH_TYPE_FULL";
|
||||
}
|
||||
return "?";
|
||||
}
|
||||
|
||||
static void test_regex() {
|
||||
printf("[%s]\n", __func__);
|
||||
auto test = [](const test_case & test_case) {
|
||||
common_regex cr(test_case.pattern);
|
||||
std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
|
||||
// std::cout << " partial rev: " << cr.reversed_partial_pattern.str() << '\n';
|
||||
for (const auto & input_output : test_case.inputs_outputs) {
|
||||
std::cout << " Input: " << input_output.input << '\n';
|
||||
auto m = cr.search(input_output.input, 0);
|
||||
if (m != input_output.output) {
|
||||
auto match_to_str = [&](const std::optional<common_regex_match> & m) {
|
||||
std::ostringstream ss;
|
||||
if (m->type == COMMON_REGEX_MATCH_TYPE_NONE) {
|
||||
ss << "<no match>";
|
||||
} else {
|
||||
GGML_ASSERT(!input_output.output.groups.empty());
|
||||
std::vector<std::string> parts;
|
||||
for (const auto & g : m->groups) {
|
||||
parts.push_back("{" + std::to_string(g.begin) + ", " + std::to_string(g.end) + "}");
|
||||
}
|
||||
ss << "{" << common_regex_match_type_name(m->type) << ", {" << string_join(parts, ", ") << "}}";
|
||||
}
|
||||
return ss.str();
|
||||
};
|
||||
std::cout << " Expected: " << match_to_str(input_output.output) << '\n';
|
||||
std::cout << " Got: " << match_to_str(m) << '\n';
|
||||
std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(test_case.pattern) << "/\n";
|
||||
|
||||
throw std::runtime_error("Test failed");
|
||||
}
|
||||
}
|
||||
};
|
||||
test({
|
||||
"a",
|
||||
{
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
|
||||
{"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
|
||||
{"ba", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 2}}}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"abcd",
|
||||
{
|
||||
{"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
|
||||
{"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
{"d", {}},
|
||||
{"bcd", {}},
|
||||
{"cde", {}},
|
||||
{"cd", {}},
|
||||
{"yeah ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{5, 7}}}},
|
||||
{"abbie", {}},
|
||||
{"", {}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
".*?ab",
|
||||
{
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
{"dab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
{"dabc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
{"da", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"a.*?b",
|
||||
{
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
{"a b", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
{"argh", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
|
||||
{"d", {}},
|
||||
{"b", {}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"ab(?:cd){2,4}ef",
|
||||
{
|
||||
// {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"abcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
|
||||
{"abcde", {}},
|
||||
{"abcdef", {}},
|
||||
{"abcdcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"abcdcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 7}}}},
|
||||
{"abcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
|
||||
{"abcdcdcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 12}}}},
|
||||
{"abcdcdcdcdcdef", {}},
|
||||
{"abcde", {}},
|
||||
{"yea", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{2, 3}}}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"a(?:rte| pure )fact",
|
||||
{
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
{"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||
{"artefa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"fact", {}},
|
||||
{"an arte", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{3, 7}}}},
|
||||
{"artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
|
||||
{"an artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{3, 11}}}},
|
||||
{"a pure", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 11}}}},
|
||||
{"it's a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{5, 16}}}},
|
||||
{"" , {}},
|
||||
{"pure", {}},
|
||||
{"pure fact", {}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"abc",
|
||||
{
|
||||
{" abcc", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 4}}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
{" ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{1, 3}}}},
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
{"b", {}},
|
||||
{"c", {}},
|
||||
{"", {}},
|
||||
}
|
||||
});
|
||||
|
||||
test({
|
||||
"(?:abc)?\\s*def",
|
||||
{
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||
{"abc ", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
|
||||
{"abc d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
|
||||
{"abc de", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"abc def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
|
||||
{"abc defg", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
|
||||
{"abc defgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
|
||||
{"abcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
|
||||
{"abcdefgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 6}}}},
|
||||
{" d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
}
|
||||
});
|
||||
|
||||
test({
|
||||
"a+b",
|
||||
{
|
||||
{"aaab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
|
||||
{"aaa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
}
|
||||
});
|
||||
|
||||
test({
|
||||
"(?:"
|
||||
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
|
||||
"(" // match 2 (open_tag)
|
||||
"<tool_call>"
|
||||
"|<function_call>"
|
||||
"|<tool>"
|
||||
"|<tools>"
|
||||
"|<response>"
|
||||
"|<json>"
|
||||
"|<xml>"
|
||||
"|<JSON>"
|
||||
")?"
|
||||
"(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
|
||||
")"
|
||||
"|<function=([^>]+)>" // match 4 (function name)
|
||||
"|<function name=\"([^\"]+)\">", // match 5 (function name again)
|
||||
{
|
||||
{"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
|
||||
{"<tool_call> {\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 18}}}},
|
||||
{"<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 17}}}},
|
||||
{"Let's call something\n<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{21, 38}}}},
|
||||
{"Ok then<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 24}}}},
|
||||
{"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 13}}}},
|
||||
{"<tool_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
|
||||
{"<function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
|
||||
{"<function name=\"special_function\"> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
|
||||
{"<function=all>", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
|
||||
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
static void test_regex_to_reversed_partial_regex() {
|
||||
printf("[%s]\n", __func__);
|
||||
|
||||
assert_equals<std::string>(
|
||||
"((?:(?:c)?b)?a)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("abc"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"(a+)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("a+"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"(a*)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("a*"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"(a?)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("a?"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"([a-z])[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("[a-z]"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"((?:\\w+)?[a-z])[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("[a-z]\\w+"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"((?:a|b))[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("(?:a|b)"));
|
||||
assert_equals<std::string>(
|
||||
"((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("abcd"));
|
||||
assert_equals<std::string>(
|
||||
"((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
|
||||
regex_to_reversed_partial_regex("a*b"));
|
||||
assert_equals<std::string>(
|
||||
"((?:(?:b)?a)?.*)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex(".*?ab"));
|
||||
assert_equals<std::string>(
|
||||
"((?:(?:b)?.*)?a)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("a.*?b"));
|
||||
assert_equals<std::string>(
|
||||
"((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("a(bc)d"));
|
||||
assert_equals<std::string>(
|
||||
"((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("a(bc|de)"));
|
||||
assert_equals<std::string>(
|
||||
"((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("ab{2,4}c"));
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_regex_to_reversed_partial_regex();
|
||||
test_regex();
|
||||
std::cout << "All tests passed.\n";
|
||||
}
|
||||
@@ -1736,7 +1736,7 @@ struct sql_printer : public printer {
|
||||
}
|
||||
};
|
||||
|
||||
static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
|
||||
static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
|
||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
@@ -1753,19 +1753,14 @@ static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
|
||||
for (int i = 1; i < n_tokens; i++) {
|
||||
tokens[i] = std::rand() % n_vocab;
|
||||
}
|
||||
int res = llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
|
||||
if (res != 0) {
|
||||
fprintf(stderr, "%s: failed to decode prompt batch, res = %d\n", __func__, res);
|
||||
return false;
|
||||
}
|
||||
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
|
||||
n_processed += n_tokens;
|
||||
}
|
||||
|
||||
llama_synchronize(ctx);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool test_gen(llama_context * ctx, int n_gen, int n_threads) {
|
||||
static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
|
||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
@@ -1775,15 +1770,10 @@ static bool test_gen(llama_context * ctx, int n_gen, int n_threads) {
|
||||
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
|
||||
|
||||
for (int i = 0; i < n_gen; i++) {
|
||||
int res = llama_decode(ctx, llama_batch_get_one(&token, 1));
|
||||
if (res != 0) {
|
||||
fprintf(stderr, "%s: failed to decode generation batch, res = %d\n", __func__, res);
|
||||
return false;
|
||||
}
|
||||
llama_decode(ctx, llama_batch_get_one(&token, 1));
|
||||
llama_synchronize(ctx);
|
||||
token = std::rand() % n_vocab;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static void llama_null_log_callback(enum ggml_log_level level, const char * text, void * user_data) {
|
||||
@@ -1927,21 +1917,13 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
|
||||
}
|
||||
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
|
||||
bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||
if (!res) {
|
||||
fprintf(stderr, "%s: error: failed to run prompt warmup\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||
}
|
||||
if (t.n_gen > 0) {
|
||||
if (params.progress) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
|
||||
}
|
||||
bool res = test_gen(ctx, 1, t.n_threads);
|
||||
if (!res) {
|
||||
fprintf(stderr, "%s: error: failed to run gen warmup\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
test_gen(ctx, 1, t.n_threads);
|
||||
}
|
||||
|
||||
for (int i = 0; i < params.reps; i++) {
|
||||
@@ -1952,11 +1934,7 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
|
||||
i + 1, params.reps);
|
||||
}
|
||||
bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
|
||||
if (!res) {
|
||||
fprintf(stderr, "%s: error: failed to run depth\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
|
||||
}
|
||||
|
||||
uint64_t t_start = get_time_ns();
|
||||
@@ -1966,22 +1944,14 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
|
||||
i + 1, params.reps);
|
||||
}
|
||||
bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||
if (!res) {
|
||||
fprintf(stderr, "%s: error: failed to run prompt\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||
}
|
||||
if (t.n_gen > 0) {
|
||||
if (params.progress) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
|
||||
i + 1, params.reps);
|
||||
}
|
||||
bool res = test_gen(ctx, t.n_gen, t.n_threads);
|
||||
if (!res) {
|
||||
fprintf(stderr, "%s: error: failed to run gen\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
test_gen(ctx, t.n_gen, t.n_threads);
|
||||
}
|
||||
|
||||
uint64_t t_ns = get_time_ns() - t_start;
|
||||
|
||||
@@ -1429,7 +1429,7 @@ struct server_slot {
|
||||
pos = text.find(word, from_pos);
|
||||
} else {
|
||||
// otherwise, partial stop
|
||||
pos = string_find_partial_stop(text, word);
|
||||
pos = find_partial_stop_string(word, text);
|
||||
}
|
||||
|
||||
if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
|
||||
@@ -2951,8 +2951,7 @@ struct server_context {
|
||||
llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
|
||||
llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
|
||||
|
||||
// add generated tokens to cache
|
||||
{
|
||||
if (slot.params.cache_prompt) {
|
||||
llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
|
||||
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
|
||||
new_tokens[i - n_discard] = new_tokens[i];
|
||||
@@ -2997,7 +2996,10 @@ struct server_context {
|
||||
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
|
||||
|
||||
slot.n_past += 1;
|
||||
slot.cache_tokens.push_back(slot.sampled);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
slot.cache_tokens.push_back(slot.sampled);
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
|
||||
slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
|
||||
@@ -3169,11 +3171,6 @@ struct server_context {
|
||||
|
||||
SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
|
||||
}
|
||||
} else {
|
||||
// if we don't cache the prompt, we have to remove the entire KV cache
|
||||
llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
|
||||
slot.n_past = 0;
|
||||
slot.cache_tokens.clear();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3207,7 +3204,7 @@ struct server_context {
|
||||
SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
|
||||
|
||||
// remove the non-common part from the cache
|
||||
slot.cache_tokens.keep_first(slot.n_past);
|
||||
slot.cache_tokens.resize(slot.n_past);
|
||||
|
||||
// check if we should process the image
|
||||
if (slot.n_past < slot.n_prompt_tokens
|
||||
@@ -3224,8 +3221,7 @@ struct server_context {
|
||||
continue;
|
||||
}
|
||||
|
||||
// add the image chunk to cache
|
||||
{
|
||||
if (slot.params.cache_prompt) {
|
||||
const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
|
||||
slot.cache_tokens.push_back(chunk.get()); // copy
|
||||
}
|
||||
@@ -3246,7 +3242,9 @@ struct server_context {
|
||||
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
|
||||
slot.cache_tokens.push_back(cur_tok);
|
||||
if (slot.params.cache_prompt) {
|
||||
slot.cache_tokens.push_back(cur_tok);
|
||||
}
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
slot.n_past++;
|
||||
|
||||
@@ -196,18 +196,6 @@ def test_cache_vs_nocache_prompt():
|
||||
assert res_cache.body["content"] == res_no_cache.body["content"]
|
||||
|
||||
|
||||
def test_nocache_long_input_prompt():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is"*32,
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_completion_with_tokens_input():
|
||||
global server
|
||||
server.temperature = 0.0
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
import pytest
|
||||
|
||||
# ensure grandparent path is in sys.path
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
from unit.test_tool_call import TEST_TOOL
|
||||
path = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(path))
|
||||
|
||||
import datetime
|
||||
from utils import *
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
TIMEOUT_SERVER_START = 15*60
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.model_alias = "tinyllama-2"
|
||||
server.server_port = 8081
|
||||
server.n_slots = 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
|
||||
@pytest.mark.parametrize("template_name,format", [
|
||||
("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
|
||||
("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
|
||||
])
|
||||
def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
|
||||
res = server.make_request("POST", "/apply-template", data={
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is today?"},
|
||||
],
|
||||
"tools": tools,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
prompt = res.body["prompt"]
|
||||
|
||||
today_str = datetime.date.today().strftime(format)
|
||||
assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
|
||||
@@ -109,7 +109,7 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
|
||||
])
|
||||
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
|
||||
global server
|
||||
n_predict = 1024
|
||||
n_predict = 512
|
||||
# server = ServerPreset.stories15m_moe()
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
|
||||
@@ -1153,7 +1153,7 @@ public:
|
||||
tokens.clear();
|
||||
}
|
||||
|
||||
void keep_first(size_t n) {
|
||||
void resize(size_t n) {
|
||||
GGML_ASSERT(n <= tokens.size());
|
||||
if (has_mtmd) {
|
||||
// we throw an error if we try to remove a token in the middle of an image
|
||||
|
||||
Reference in New Issue
Block a user