mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-09 16:17:31 +03:00
Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
009a113326 | ||
|
|
c8ac02fa1b | ||
|
|
4ef9301e4d | ||
|
|
ddf03c6d9a | ||
|
|
26229755c5 | ||
|
|
057dba336e | ||
|
|
501aeed18f | ||
|
|
0ec191e1d7 | ||
|
|
243532e556 | ||
|
|
5e9c635463 | ||
|
|
9949ad08f6 | ||
|
|
3ee9da0e4f | ||
|
|
75511a8d7e | ||
|
|
b54cb2e3d0 | ||
|
|
8a65a7a8ee | ||
|
|
8a132faaa0 | ||
|
|
4293919068 | ||
|
|
d12cc3d1ca | ||
|
|
2dcb7f74ed | ||
|
|
660600081f | ||
|
|
d9a12c82f0 | ||
|
|
4a05e0c566 | ||
|
|
e9fd96283d | ||
|
|
3ba12fed0a | ||
|
|
5473949070 | ||
|
|
dcdcbad42a | ||
|
|
5764d7c6a6 | ||
|
|
87f4744a80 | ||
|
|
85d482e6b6 | ||
|
|
ae65fbdf33 | ||
|
|
3bd9aa1f92 | ||
|
|
ece522f98c |
8
.github/labeler.yml
vendored
8
.github/labeler.yml
vendored
@@ -73,10 +73,18 @@ android:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- examples/llama.android/**
|
||||
server/webui:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- tools/server/webui/**
|
||||
- tools/server/public/**
|
||||
server:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- tools/server/**
|
||||
|
||||
|
||||
|
||||
ggml:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
|
||||
@@ -332,58 +332,36 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
const auto & inputs = ctx.inputs;
|
||||
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
auto until_suffix = p.rule("until-suffix", p.until(arguments.value_suffix));
|
||||
|
||||
common_peg_parser tool_choice = p.choice();
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & func = tool.at("function");
|
||||
std::string name = func.at("name");
|
||||
const auto & params = func.contains("parameters") ? func.at("parameters") : json::object();
|
||||
auto params = func.contains("parameters") ? func.at("parameters") : json::object();
|
||||
const auto & properties = params.contains("properties") ? params.at("properties") : json::object();
|
||||
|
||||
std::set<std::string> required;
|
||||
if (params.contains("required")) {
|
||||
params.at("required").get_to(required);
|
||||
}
|
||||
|
||||
auto schema_info = common_schema_info();
|
||||
schema_info.resolve_refs(params);
|
||||
|
||||
// Build parser for each argument, separating required and optional
|
||||
std::vector<common_peg_parser> required_parsers;
|
||||
std::vector<common_peg_parser> optional_parsers;
|
||||
for (const auto & [param_name, param_schema] : properties.items()) {
|
||||
bool is_required = required.find(param_name) != required.end();
|
||||
std::string type = "object";
|
||||
if (param_schema.contains("type")) {
|
||||
const auto & type_obj = param_schema.at("type");
|
||||
if (type_obj.is_string()) {
|
||||
type_obj.get_to(type);
|
||||
} else if (type_obj.is_array()) {
|
||||
// Handle nullable types like ["string", "null"]
|
||||
for (const auto & t : type_obj) {
|
||||
if (t.is_string() && t.get<std::string>() != "null") {
|
||||
type = t.get<std::string>();
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else if (type_obj.is_object()) {
|
||||
if (type_obj.contains("type") && type_obj.at("type").is_string()) {
|
||||
type_obj.at("type").get_to(type);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Infer string type from enum values when type is unspecified
|
||||
if (type == "object" && param_schema.contains("enum")) {
|
||||
const auto & enum_vals = param_schema.at("enum");
|
||||
if (enum_vals.is_array()) {
|
||||
for (const auto & v : enum_vals) {
|
||||
if (v.is_string()) {
|
||||
type = "string";
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
bool is_required = required.find(param_name) != required.end();
|
||||
|
||||
auto arg =
|
||||
p.tool_arg(p.tool_arg_open(arguments.name_prefix + p.tool_arg_name(p.literal(param_name)) +
|
||||
arguments.name_suffix) +
|
||||
arguments.value_prefix +
|
||||
(type == "string" ?
|
||||
p.tool_arg_string_value(p.schema(p.until(arguments.value_suffix),
|
||||
(schema_info.resolves_to_string(param_schema) ?
|
||||
p.tool_arg_string_value(p.schema(until_suffix,
|
||||
"tool-" + name + "-arg-" + param_name + "-schema",
|
||||
param_schema, true)) :
|
||||
p.tool_arg_json_value(p.schema(
|
||||
@@ -414,7 +392,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
for (const auto & opt : optional_parsers) {
|
||||
any_opt |= opt;
|
||||
}
|
||||
args_seq = args_seq + p.repeat(p.space() + any_opt, 0, (int) optional_parsers.size());
|
||||
args_seq = args_seq + p.repeat(p.space() + any_opt, 0, -1);
|
||||
}
|
||||
|
||||
if (!arguments.start.empty()) {
|
||||
|
||||
@@ -1124,7 +1124,7 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
p.rule("gemma4-bool", p.json_bool());
|
||||
p.rule("gemma4-null", p.json_null());
|
||||
p.rule("gemma4-number", p.json_number());
|
||||
p.rule("gemma4-dict-key", p.rule("gemma4-dict-key-name", p.until(":")) + p.literal(":"));
|
||||
p.rule("gemma4-dict-key", p.rule("gemma4-dict-key-name", p.chars("[^:}]", 1, -1)) + p.literal(":"));
|
||||
p.rule("gemma4-dict-kv", p.ref("gemma4-dict-key") + p.space() + p.ref("gemma4-value"));
|
||||
p.rule("gemma4-dict", [&]() {
|
||||
auto ws = p.space();
|
||||
@@ -1963,7 +1963,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
|
||||
params.generation_prompt = diff.right;
|
||||
params.generation_prompt = diff.right + diff.suffix;
|
||||
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
|
||||
|
||||
@@ -591,6 +591,10 @@ static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
|
||||
for (const auto & f : files) {
|
||||
if (gguf_filename_is_model(f.path) &&
|
||||
std::regex_search(f.path, pattern)) {
|
||||
auto split = get_gguf_split_info(f.path);
|
||||
if (split.count > 1 && split.index != 1) {
|
||||
continue;
|
||||
}
|
||||
return f;
|
||||
}
|
||||
}
|
||||
@@ -600,6 +604,10 @@ static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
|
||||
if (tag.empty()) {
|
||||
for (const auto & f : files) {
|
||||
if (gguf_filename_is_model(f.path)) {
|
||||
auto split = get_gguf_split_info(f.path);
|
||||
if (split.count > 1 && split.index != 1) {
|
||||
continue;
|
||||
}
|
||||
return f;
|
||||
}
|
||||
}
|
||||
@@ -618,6 +626,7 @@ static void list_available_gguf_files(const hf_cache::hf_files & files) {
|
||||
}
|
||||
|
||||
struct hf_plan {
|
||||
hf_cache::hf_file primary;
|
||||
hf_cache::hf_files model_files;
|
||||
hf_cache::hf_file mmproj;
|
||||
};
|
||||
@@ -663,6 +672,7 @@ static hf_plan get_hf_plan(const common_params_model & model,
|
||||
}
|
||||
}
|
||||
|
||||
plan.primary = primary;
|
||||
plan.model_files = get_split_files(all, primary);
|
||||
|
||||
if (opts.download_mmproj) {
|
||||
@@ -749,7 +759,7 @@ common_download_model_result common_download_model(const common_params_model
|
||||
for (const auto & f : hf.model_files) {
|
||||
hf_cache::finalize_file(f);
|
||||
}
|
||||
result.model_path = hf.model_files[0].final_path;
|
||||
result.model_path = hf.primary.final_path;
|
||||
|
||||
if (!hf.mmproj.path.empty()) {
|
||||
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
|
||||
|
||||
@@ -251,6 +251,23 @@ value binary_expression::execute_impl(context & ctx) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// Python-style string repetition
|
||||
// TODO: support array/tuple repetition (e.g., [1, 2] * 3 → [1, 2, 1, 2, 1, 2])
|
||||
if (op.value == "*" &&
|
||||
((is_val<value_string>(left_val) && is_val<value_int>(right_val)) ||
|
||||
(is_val<value_int>(left_val) && is_val<value_string>(right_val)))) {
|
||||
const auto & str = is_val<value_string>(left_val) ? left_val->as_string() : right_val->as_string();
|
||||
const int64_t repeat = is_val<value_int>(right_val) ? right_val->as_int() : left_val->as_int();
|
||||
auto res = mk_val<value_string>();
|
||||
if (repeat <= 0) {
|
||||
return res;
|
||||
}
|
||||
for (int64_t i = 0; i < repeat; ++i) {
|
||||
res->val_str = res->val_str.append(str);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
// String membership
|
||||
if (is_val<value_string>(left_val) && is_val<value_string>(right_val)) {
|
||||
// case: "a" in "abc"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "runtime.h"
|
||||
#include "unicode.h"
|
||||
#include "value.h"
|
||||
|
||||
// for converting from JSON to jinja values
|
||||
@@ -154,6 +155,83 @@ static value test_compare_fn(const func_args & args) {
|
||||
return mk_val<value_bool>(value_compare(args.get_pos(0), args.get_pos(1), op));
|
||||
}
|
||||
|
||||
static void append_codepoint_as_ascii_json_escape(std::string & out, uint32_t codepoint) {
|
||||
auto append_u16 = [&out](uint32_t value) {
|
||||
char buf[8];
|
||||
snprintf(buf, sizeof(buf), "\\u%04x", static_cast<unsigned int>(value));
|
||||
out += buf;
|
||||
};
|
||||
|
||||
if (codepoint <= 0xFFFF) {
|
||||
append_u16(codepoint);
|
||||
return;
|
||||
}
|
||||
|
||||
codepoint -= 0x10000;
|
||||
append_u16(0xD800 + ((codepoint >> 10) & 0x3FF));
|
||||
append_u16(0xDC00 + (codepoint & 0x3FF));
|
||||
}
|
||||
|
||||
static std::string json_ensure_ascii_preserving_format(const std::string & json_str) {
|
||||
std::string output;
|
||||
output.reserve(json_str.size());
|
||||
|
||||
bool in_string = false;
|
||||
bool escaped = false;
|
||||
|
||||
for (size_t pos = 0; pos < json_str.size();) {
|
||||
const char ch = json_str[pos];
|
||||
if (!in_string) {
|
||||
output.push_back(ch);
|
||||
if (ch == '"') {
|
||||
in_string = true;
|
||||
}
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (escaped) {
|
||||
output.push_back(ch);
|
||||
escaped = false;
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ch == '\\') {
|
||||
output.push_back(ch);
|
||||
escaped = true;
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ch == '"') {
|
||||
output.push_back(ch);
|
||||
in_string = false;
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
const unsigned char uch = static_cast<unsigned char>(ch);
|
||||
if (uch < 0x80) {
|
||||
output.push_back(ch);
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto parsed = common_parse_utf8_codepoint(json_str, pos);
|
||||
if (parsed.status != utf8_parse_result::SUCCESS) {
|
||||
output += "\\ufffd";
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
append_codepoint_as_ascii_json_escape(output, parsed.codepoint);
|
||||
pos += parsed.bytes_consumed;
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
static value tojson(const func_args & args) {
|
||||
args.ensure_count(1, 5);
|
||||
value val_ascii = args.get_kwarg_or_pos("ensure_ascii", 1);
|
||||
@@ -169,16 +247,17 @@ static value tojson(const func_args & args) {
|
||||
if (is_val<value_int>(val_indent)) {
|
||||
indent = static_cast<int>(val_indent->as_int());
|
||||
}
|
||||
if (val_ascii->as_bool()) { // undefined == false
|
||||
throw not_implemented_exception("tojson ensure_ascii=true not implemented");
|
||||
}
|
||||
if (val_sort->as_bool()) { // undefined == false
|
||||
throw not_implemented_exception("tojson sort_keys=true not implemented");
|
||||
}
|
||||
const bool ensure_ascii = val_ascii->as_bool(); // undefined == false
|
||||
auto separators = (is_val<value_array>(val_separators) ? val_separators : mk_val<value_array>())->as_array();
|
||||
std::string item_sep = separators.size() > 0 ? separators[0]->as_string().str() : (indent < 0 ? ", " : ",");
|
||||
std::string key_sep = separators.size() > 1 ? separators[1]->as_string().str() : ": ";
|
||||
std::string json_str = value_to_json(args.get_pos(0), indent, item_sep, key_sep);
|
||||
if (ensure_ascii) {
|
||||
json_str = json_ensure_ascii_preserving_format(json_str);
|
||||
}
|
||||
return mk_val<value_string>(json_str);
|
||||
}
|
||||
|
||||
@@ -460,6 +539,10 @@ const func_builtins & value_int_t::get_builtins() const {
|
||||
int64_t val = args.get_pos(0)->as_int();
|
||||
return mk_val<value_int>(val < 0 ? -val : val);
|
||||
}},
|
||||
{"int", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_int>();
|
||||
return mk_val<value_int>(args.get_pos(0)->as_int());
|
||||
}},
|
||||
{"float", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_int>();
|
||||
double val = static_cast<double>(args.get_pos(0)->as_int());
|
||||
@@ -486,6 +569,10 @@ const func_builtins & value_float_t::get_builtins() const {
|
||||
int64_t val = static_cast<int64_t>(args.get_pos(0)->as_float());
|
||||
return mk_val<value_int>(val);
|
||||
}},
|
||||
{"float", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_float>();
|
||||
return mk_val<value_float>(args.get_pos(0)->as_float());
|
||||
}},
|
||||
{"safe", tojson},
|
||||
{"string", tojson},
|
||||
{"tojson", tojson},
|
||||
|
||||
@@ -1229,15 +1229,15 @@ class TextModel(ModelBase):
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
|
||||
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
|
||||
assert max(tokenizer.vocab.values()) < vocab_size
|
||||
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) # ty: ignore[unresolved-attribute]
|
||||
assert max(tokenizer.vocab.values()) < vocab_size # ty: ignore[unresolved-attribute]
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} # ty: ignore[unresolved-attribute]
|
||||
added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute]
|
||||
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder # ty: ignore[unresolved-attribute]
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
@@ -1250,7 +1250,7 @@ class TextModel(ModelBase):
|
||||
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
|
||||
if not added_tokens_decoder[i].normalized:
|
||||
previous_token = token
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) # ty: ignore[unresolved-attribute, invalid-assignment]
|
||||
if previous_token != token:
|
||||
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
|
||||
|
||||
@@ -1583,13 +1583,13 @@ class TextModel(ModelBase):
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
|
||||
vocab_size = hparams["vocab_size"]
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size # ty: ignore[unresolved-attribute]
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
merges = []
|
||||
vocab = {}
|
||||
mergeable_ranks = tokenizer.mergeable_ranks
|
||||
mergeable_ranks = tokenizer.mergeable_ranks # ty: ignore[unresolved-attribute]
|
||||
for token, rank in mergeable_ranks.items():
|
||||
vocab[QwenModel.token_bytes_to_string(token)] = rank
|
||||
if len(token) == 1:
|
||||
@@ -1599,7 +1599,7 @@ class TextModel(ModelBase):
|
||||
merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
|
||||
|
||||
# for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
|
||||
added_vocab = tokenizer.special_tokens
|
||||
added_vocab = tokenizer.special_tokens # ty: ignore[unresolved-attribute]
|
||||
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
|
||||
|
||||
for i in range(vocab_size):
|
||||
@@ -1622,10 +1622,10 @@ class TextModel(ModelBase):
|
||||
special_vocab.merges = merges
|
||||
# only add special tokens when they were not already loaded from config.json
|
||||
if len(special_vocab.special_token_ids) == 0:
|
||||
special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
|
||||
special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
# this one is usually not in config.json anyway
|
||||
special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
|
||||
special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_sentencepiece(self, add_to_gguf=True):
|
||||
@@ -1877,10 +1877,10 @@ class TextModel(ModelBase):
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_glm(self):
|
||||
@@ -1894,10 +1894,10 @@ class TextModel(ModelBase):
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
# Special tokens
|
||||
# Note: Using <|endoftext|> (151329) for eot causes endless generation
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
|
||||
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # ty: ignore[unresolved-attribute] # 151331
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # ty: ignore[unresolved-attribute] # 151336
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute] # 151329
|
||||
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # ty: ignore[unresolved-attribute] # 151338
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_interns1(self):
|
||||
@@ -1906,16 +1906,16 @@ class TextModel(ModelBase):
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
|
||||
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab()) # ty: ignore[unresolved-attribute]
|
||||
vocab_size = self.hparams.get("vocab_size", len(vocab))
|
||||
assert max(vocab.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute]
|
||||
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder # ty: ignore[unresolved-attribute]
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
@@ -1928,7 +1928,7 @@ class TextModel(ModelBase):
|
||||
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
|
||||
if not added_tokens_decoder[i].normalized:
|
||||
previous_token = token
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) # ty: ignore[unresolved-attribute, invalid-assignment]
|
||||
if previous_token != token:
|
||||
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
|
||||
|
||||
@@ -2516,15 +2516,15 @@ class XverseModel(TextModel):
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model)
|
||||
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
|
||||
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab)) # ty: ignore[unresolved-attribute]
|
||||
# Since we are checking the maximum index, we need to ensure it's strictly less than vocab_size,
|
||||
# because vocab_size is the count of items, and indexes start at 0.
|
||||
max_vocab_index = max(tokenizer.get_vocab().values())
|
||||
max_vocab_index = max(tokenizer.get_vocab().values()) # ty: ignore[unresolved-attribute]
|
||||
if max_vocab_index >= vocab_size:
|
||||
raise ValueError("Vocabulary size exceeds expected maximum size.")
|
||||
|
||||
reverse_vocab: dict[int, str] = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
reverse_vocab: dict[int, str] = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} # ty: ignore[unresolved-attribute]
|
||||
added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute]
|
||||
|
||||
for token_id in range(vocab_size):
|
||||
token_text = reverse_vocab[token_id].encode('utf-8')
|
||||
@@ -2535,7 +2535,7 @@ class XverseModel(TextModel):
|
||||
elif re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
|
||||
toktype = gguf.TokenType.BYTE # special
|
||||
elif reverse_vocab[token_id] in added_vocab:
|
||||
if tokenizer.added_tokens_decoder[token_id].special:
|
||||
if tokenizer.added_tokens_decoder[token_id].special: # ty: ignore[unresolved-attribute]
|
||||
toktype = gguf.TokenType.CONTROL
|
||||
else:
|
||||
toktype = gguf.TokenType.USER_DEFINED
|
||||
@@ -3752,7 +3752,7 @@ class QwenModel(TextModel):
|
||||
|
||||
@staticmethod
|
||||
def token_bytes_to_string(b):
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode # ty: ignore[unresolved-import]
|
||||
byte_encoder = bytes_to_unicode()
|
||||
return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
|
||||
|
||||
@@ -3777,7 +3777,14 @@ class QwenModel(TextModel):
|
||||
self._set_vocab_qwen()
|
||||
|
||||
|
||||
@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration", "KORMoForCausalLM", "AudioFlamingo3ForConditionalGeneration")
|
||||
@ModelBase.register(
|
||||
"Qwen2Model",
|
||||
"Qwen2ForCausalLM",
|
||||
"Qwen2AudioForConditionalGeneration",
|
||||
"KORMoForCausalLM",
|
||||
"AudioFlamingo3ForConditionalGeneration",
|
||||
"DotsOCRForCausalLM",
|
||||
)
|
||||
class Qwen2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN2
|
||||
|
||||
@@ -3798,7 +3805,8 @@ class Qwen2Model(TextModel):
|
||||
name = name.replace("language_model.", "") # for InternVL
|
||||
if name.startswith("mlp") or name.startswith("multi_modal_projector") \
|
||||
or name.startswith("vision_model") or name.startswith("audio_tower") \
|
||||
or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
|
||||
or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") \
|
||||
or name.startswith("vision_tower."):
|
||||
# skip vision and audio tensors
|
||||
return
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
@@ -3815,14 +3823,14 @@ class DreamModel(TextModel):
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
|
||||
vocab_dict = tokenizer.get_vocab()
|
||||
vocab_dict = tokenizer.get_vocab() # ty: ignore[unresolved-attribute]
|
||||
vocab_size = self.hparams.get("vocab_size", len(vocab_dict))
|
||||
assert max(vocab_dict.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab_dict.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute]
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
@@ -3880,14 +3888,14 @@ class LLaDAModel(TextModel):
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
|
||||
vocab_dict = tokenizer.get_vocab()
|
||||
vocab_dict = tokenizer.get_vocab() # ty: ignore[unresolved-attribute]
|
||||
vocab_size = self.hparams.get("vocab_size", len(vocab_dict))
|
||||
assert max(vocab_dict.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab_dict.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute]
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
@@ -4665,9 +4673,9 @@ class Qwen3Model(Qwen2Model):
|
||||
|
||||
self.is_rerank = True
|
||||
self.is_tied_embeddings = self.hparams.get("tie_word_embeddings", False)
|
||||
self.token_false_id = tokenizer.convert_tokens_to_ids("no")
|
||||
self.token_true_id = tokenizer.convert_tokens_to_ids("yes")
|
||||
self.sep_token_id = tokenizer.convert_tokens_to_ids("|")
|
||||
self.token_false_id = tokenizer.convert_tokens_to_ids("no") # ty: ignore[unresolved-attribute, invalid-assignment]
|
||||
self.token_true_id = tokenizer.convert_tokens_to_ids("yes") # ty: ignore[unresolved-attribute, invalid-assignment]
|
||||
self.sep_token_id = tokenizer.convert_tokens_to_ids("|") # ty: ignore[unresolved-attribute]
|
||||
|
||||
assert self.token_false_id is not None and self.token_true_id is not None
|
||||
|
||||
@@ -5936,7 +5944,7 @@ class KimiLinearModel(TextModel):
|
||||
# Build merges list using the approach similar to HunYuanMoE
|
||||
merges = []
|
||||
vocab = {}
|
||||
mergeable_ranks = tokenizer.model._mergeable_ranks
|
||||
mergeable_ranks = tokenizer.model._mergeable_ranks # ty: ignore[unresolved-attribute]
|
||||
for token, rank in mergeable_ranks.items():
|
||||
vocab[QwenModel.token_bytes_to_string(token)] = rank
|
||||
if len(token) == 1:
|
||||
@@ -5946,7 +5954,7 @@ class KimiLinearModel(TextModel):
|
||||
merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
|
||||
# Build token list
|
||||
vocab_size = self.hparams["vocab_size"]
|
||||
special_tokens = tokenizer.special_tokens
|
||||
special_tokens = tokenizer.special_tokens # ty: ignore[unresolved-attribute]
|
||||
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
@@ -5972,7 +5980,7 @@ class KimiLinearModel(TextModel):
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
# override eos id in config.json with tiktoken eos id
|
||||
self.gguf_writer.add_eos_token_id(tokenizer.eos_id)
|
||||
self.gguf_writer.add_eos_token_id(tokenizer.eos_id) # ty: ignore[unresolved-attribute]
|
||||
else:
|
||||
raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")
|
||||
|
||||
@@ -6466,11 +6474,11 @@ class BertModel(TextModel):
|
||||
with open(tokenizer_config_path, "r", encoding="utf-8") as fp:
|
||||
tokenizer_config_json = json.load(fp)
|
||||
|
||||
add_prefix = tokenizer.add_prefix_space
|
||||
remove_whitespaces = tokenizer.clean_up_tokenization_spaces
|
||||
add_prefix = tokenizer.add_prefix_space # ty: ignore[unresolved-attribute]
|
||||
remove_whitespaces = tokenizer.clean_up_tokenization_spaces # ty: ignore[unresolved-attribute]
|
||||
precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])
|
||||
|
||||
vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size)
|
||||
vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size) # ty: ignore[unresolved-attribute]
|
||||
else:
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
@@ -6487,7 +6495,7 @@ class BertModel(TextModel):
|
||||
|
||||
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
|
||||
scores: list[float] = [-10000.0] * vocab_size
|
||||
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
|
||||
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size # ty: ignore[invalid-assignment]
|
||||
|
||||
if isinstance(tokenizer, SentencePieceProcessor):
|
||||
for token_id in range(tokenizer.vocab_size()):
|
||||
@@ -6509,20 +6517,20 @@ class BertModel(TextModel):
|
||||
scores[token_id] = score
|
||||
toktypes[token_id] = toktype
|
||||
else:
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute]
|
||||
unk_token = tokenizer_config_json.get("unk_token")
|
||||
unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3))
|
||||
unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3)) # ty: ignore[no-matching-overload]
|
||||
|
||||
for token_id in range(tokenizer.vocab_size):
|
||||
piece = tokenizer._convert_id_to_token(token_id)
|
||||
if (piece := tokenizer._convert_id_to_token(token_id)) is not None:
|
||||
for token_id in range(tokenizer.vocab_size): # ty: ignore[unresolved-attribute]
|
||||
piece = tokenizer._convert_id_to_token(token_id) # ty: ignore[unresolved-attribute]
|
||||
if (piece := tokenizer._convert_id_to_token(token_id)) is not None: # ty: ignore[unresolved-attribute]
|
||||
text = piece.encode("utf-8")
|
||||
score = tokenizer_json["model"]["vocab"][token_id][1]
|
||||
|
||||
toktype = SentencePieceTokenTypes.NORMAL
|
||||
if token_id == unk_token_id:
|
||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
||||
elif token_id in tokenizer.all_special_ids:
|
||||
elif token_id in tokenizer.all_special_ids: # ty: ignore[unresolved-attribute]
|
||||
toktype = SentencePieceTokenTypes.CONTROL
|
||||
elif token_id in added_vocab.values():
|
||||
toktype = SentencePieceTokenTypes.USER_DEFINED
|
||||
@@ -8831,7 +8839,7 @@ class DeepseekV2Model(TextModel):
|
||||
# Build merges list using the approach similar to HunYuanMoE
|
||||
merges = []
|
||||
vocab = {}
|
||||
mergeable_ranks = tokenizer.model._mergeable_ranks
|
||||
mergeable_ranks = tokenizer.model._mergeable_ranks # ty: ignore[unresolved-attribute]
|
||||
for token, rank in mergeable_ranks.items():
|
||||
vocab[QwenModel.token_bytes_to_string(token)] = rank
|
||||
if len(token) == 1:
|
||||
@@ -8842,7 +8850,7 @@ class DeepseekV2Model(TextModel):
|
||||
|
||||
# Build token list
|
||||
vocab_size = self.hparams["vocab_size"]
|
||||
special_tokens = tokenizer.special_tokens
|
||||
special_tokens = tokenizer.special_tokens # ty: ignore[unresolved-attribute]
|
||||
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
@@ -9813,10 +9821,10 @@ class Glm4Model(TextModel):
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
@@ -10044,12 +10052,12 @@ class ChatGLMModel(TextModel):
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
|
||||
vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size
|
||||
vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab())) # ty: ignore[unresolved-attribute]
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size # ty: ignore[unresolved-attribute]
|
||||
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
|
||||
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
|
||||
for token_id in range(vocab_size):
|
||||
piece = tokenizer._convert_id_to_token(token_id)
|
||||
piece = tokenizer._convert_id_to_token(token_id) # ty: ignore[unresolved-attribute]
|
||||
if token_id == 0:
|
||||
piece = "<unk>"
|
||||
elif token_id == 1:
|
||||
@@ -10057,17 +10065,17 @@ class ChatGLMModel(TextModel):
|
||||
elif token_id == 2:
|
||||
piece = "<eos>"
|
||||
|
||||
text = piece.encode("utf-8")
|
||||
text = piece.encode("utf-8") # ty: ignore[unresolved-attribute]
|
||||
score = 0.0
|
||||
# Referencing the tokenizer Python implementation(https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py),
|
||||
# it is only valid if it is less than tokenizer.tokenizer.sp_model.vocab_size()
|
||||
if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size():
|
||||
score = tokenizer.tokenizer.sp_model.get_score(token_id)
|
||||
if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size(): # ty: ignore[unresolved-attribute, invalid-argument-type]
|
||||
score = tokenizer.tokenizer.sp_model.get_score(token_id) # ty: ignore[unresolved-attribute]
|
||||
|
||||
if token_id >= tokenizer.tokenizer.sp_model.vocab_size():
|
||||
if token_id >= tokenizer.tokenizer.sp_model.vocab_size(): # ty: ignore[unresolved-attribute]
|
||||
if piece in special_tokens:
|
||||
toktype = SentencePieceTokenTypes.CONTROL
|
||||
elif len(piece) == 0:
|
||||
elif len(piece) == 0: # ty: ignore[invalid-argument-type]
|
||||
text = f"[PAD{token_id}]".encode("utf-8")
|
||||
toktype = SentencePieceTokenTypes.UNUSED
|
||||
else:
|
||||
@@ -10078,13 +10086,13 @@ class ChatGLMModel(TextModel):
|
||||
continue
|
||||
|
||||
toktype = SentencePieceTokenTypes.NORMAL
|
||||
if tokenizer.tokenizer.sp_model.is_unknown(token_id):
|
||||
if tokenizer.tokenizer.sp_model.is_unknown(token_id): # ty: ignore[unresolved-attribute]
|
||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
||||
elif tokenizer.tokenizer.sp_model.is_control(token_id):
|
||||
elif tokenizer.tokenizer.sp_model.is_control(token_id): # ty: ignore[unresolved-attribute]
|
||||
toktype = SentencePieceTokenTypes.CONTROL
|
||||
elif tokenizer.tokenizer.sp_model.is_unused(token_id):
|
||||
elif tokenizer.tokenizer.sp_model.is_unused(token_id): # ty: ignore[unresolved-attribute]
|
||||
toktype = SentencePieceTokenTypes.UNUSED
|
||||
elif tokenizer.tokenizer.sp_model.is_byte(token_id):
|
||||
elif tokenizer.tokenizer.sp_model.is_byte(token_id): # ty: ignore[unresolved-attribute]
|
||||
toktype = SentencePieceTokenTypes.BYTE
|
||||
|
||||
tokens.append(text)
|
||||
@@ -10104,7 +10112,7 @@ class ChatGLMModel(TextModel):
|
||||
|
||||
@staticmethod
|
||||
def token_bytes_to_string(b):
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode # ty: ignore[unresolved-import]
|
||||
byte_encoder = bytes_to_unicode()
|
||||
return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
|
||||
|
||||
@@ -10138,7 +10146,7 @@ class ChatGLMModel(TextModel):
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
|
||||
vocab_size = hparams.get("padded_vocab_size",hparams["vocab_size"])
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size # ty: ignore[unresolved-attribute]
|
||||
|
||||
tokens, toktypes, tokpre = self.get_vocab_base()
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
@@ -10147,10 +10155,10 @@ class ChatGLMModel(TextModel):
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
# only add special tokens when they were not already loaded from config.json
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # ty: ignore[unresolved-attribute]
|
||||
# this one is usually not in config.json anyway
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
@@ -11416,7 +11424,7 @@ class HunYuanMoEModel(TextModel):
|
||||
# 2. Reverse-engineer the merges list from mergeable_ranks
|
||||
merges = []
|
||||
vocab = {}
|
||||
mergeable_ranks = tokenizer.mergeable_ranks
|
||||
mergeable_ranks = tokenizer.mergeable_ranks # ty: ignore[unresolved-attribute]
|
||||
for token, rank in mergeable_ranks.items():
|
||||
vocab[QwenModel.token_bytes_to_string(token)] = rank
|
||||
if len(token) == 1:
|
||||
@@ -11427,8 +11435,8 @@ class HunYuanMoEModel(TextModel):
|
||||
|
||||
# 3. Generate the tokens and toktypes lists
|
||||
vocab_size = self.hparams["vocab_size"]
|
||||
assert tokenizer.vocab_size == vocab_size
|
||||
special_tokens = tokenizer.special_tokens
|
||||
assert tokenizer.vocab_size == vocab_size # ty: ignore[unresolved-attribute]
|
||||
special_tokens = tokenizer.special_tokens # ty: ignore[unresolved-attribute]
|
||||
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
@@ -11652,7 +11660,7 @@ class HunYuanModel(TextModel):
|
||||
# 2. Reverse-engineer the merges list from mergeable_ranks
|
||||
merges = []
|
||||
vocab = {}
|
||||
mergeable_ranks = tokenizer.mergeable_ranks
|
||||
mergeable_ranks = tokenizer.mergeable_ranks # ty: ignore[unresolved-attribute]
|
||||
for token, rank in mergeable_ranks.items():
|
||||
vocab[QwenModel.token_bytes_to_string(token)] = rank
|
||||
if len(token) == 1:
|
||||
@@ -11663,8 +11671,8 @@ class HunYuanModel(TextModel):
|
||||
|
||||
# 3. Generate the tokens and toktypes lists
|
||||
vocab_size = self.hparams["vocab_size"]
|
||||
assert tokenizer.vocab_size == vocab_size
|
||||
special_tokens = tokenizer.special_tokens
|
||||
assert tokenizer.vocab_size == vocab_size # ty: ignore[unresolved-attribute]
|
||||
special_tokens = tokenizer.special_tokens # ty: ignore[unresolved-attribute]
|
||||
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
@@ -12812,13 +12820,44 @@ class SolarOpenModel(Glm4MoeModel):
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<unk>"])
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|startoftext|>"])
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|endoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<unk>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|startoftext|>"]) # ty: ignore[unresolved-attribute]
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
|
||||
@ModelBase.register("DotsOCRForCausalLM")
|
||||
class DotsOCRVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
self.hparams_vision["image_size"] = 0 # dynamic resolution
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.DOTSOCR)
|
||||
self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"])
|
||||
self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"])
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(self.find_vparam(["rms_norm_eps"]))
|
||||
self.gguf_writer.add_vision_projector_scale_factor(self.find_vparam(["spatial_merge_size"]))
|
||||
self.gguf_writer.add_vision_use_silu(True)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name.startswith("vision_tower."):
|
||||
if "vision_tower.blocks." in name and ".mlp." in name:
|
||||
# note: to avoid naming conflicts in tensor_mapping.py, we need to handle FFN renaming here
|
||||
# x = F.silu(self.fc1(x)) * self.fc3(x)
|
||||
# x = self.fc2(x)
|
||||
# fc1 -> gate, fc2 -> down, fc3 -> up
|
||||
# mapping original names to Qwen2.5 naming scheme
|
||||
name = name.replace("vision_tower.blocks.", "visual.blocks.")
|
||||
name = name.replace(".fc1", ".gate_proj")
|
||||
name = name.replace(".fc2", ".down_proj")
|
||||
name = name.replace(".fc3", ".up_proj")
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
||||
@@ -296,7 +296,7 @@ for model in [*pre_computed_hashes, *all_models]:
|
||||
except Exception as e:
|
||||
raise OSError(f"Error loading tokenizer for model {name}.") from e
|
||||
|
||||
chktok = tokenizer.encode(CHK_TXT)
|
||||
chktok = tokenizer.encode(CHK_TXT) # ty: ignore[unresolved-attribute]
|
||||
chkhsh = sha256(str(chktok).encode()).hexdigest()
|
||||
|
||||
logger.info(f"model: {name}")
|
||||
@@ -468,7 +468,7 @@ for model in models:
|
||||
|
||||
with open(f"models/ggml-vocab-{name}.gguf.out", "w") as f:
|
||||
for text in tests:
|
||||
res = tokenizer.encode(text, add_special_tokens=False)
|
||||
res = tokenizer.encode(text, add_special_tokens=False) # ty: ignore[unresolved-attribute]
|
||||
for r in res:
|
||||
f.write(f" {r}")
|
||||
f.write("\n")
|
||||
|
||||
@@ -402,7 +402,7 @@ if __name__ == '__main__':
|
||||
# the invocation string includes the "<|start_of_turn|>"
|
||||
# token, but the adapters themselves were trained to
|
||||
# activate _after_ that first token, so we drop it here.
|
||||
alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:]
|
||||
alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:] # ty: ignore[call-non-callable]
|
||||
if alora_invocation_tokens:
|
||||
logger.debug("GGUF KV: %s = %s", gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, alora_invocation_tokens)
|
||||
self.gguf_writer.add_key_value(
|
||||
|
||||
@@ -37,6 +37,7 @@ llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload
|
||||
> - PaddleOCR-VL: https://github.com/ggml-org/llama.cpp/pull/18825
|
||||
> - GLM-OCR: https://github.com/ggml-org/llama.cpp/pull/19677
|
||||
> - Deepseek-OCR: https://github.com/ggml-org/llama.cpp/pull/17400
|
||||
> - Dots.OCR: https://github.com/ggml-org/llama.cpp/pull/17575
|
||||
> - HunyuanOCR: https://github.com/ggml-org/llama.cpp/pull/21395
|
||||
|
||||
## Pre-quantized models
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <vector>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <optional>
|
||||
#include <regex>
|
||||
|
||||
static void print_usage(int /*argc*/, char ** argv) {
|
||||
@@ -222,7 +223,10 @@ int main(int argc, char ** argv) {
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
base_callback_data cb_data(params, params.tensor_filter);
|
||||
std::optional<base_callback_data> cb_data;
|
||||
if (!params.save_logits) {
|
||||
cb_data.emplace(params, params.tensor_filter);
|
||||
}
|
||||
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
|
||||
@@ -53,10 +53,10 @@ model_name = os.path.basename(model_path)
|
||||
print(f"Model name: {model_name}")
|
||||
|
||||
prompt = "Hello world today"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids # ty: ignore[call-non-callable]
|
||||
print(f"Input tokens: {input_ids}")
|
||||
print(f"Input text: {repr(prompt)}")
|
||||
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
|
||||
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}") # ty: ignore[unresolved-attribute]
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids, output_hidden_states=True)
|
||||
@@ -92,7 +92,7 @@ with torch.no_grad():
|
||||
|
||||
# Print embeddings per token in the requested format
|
||||
print("\nToken embeddings:")
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids[0]) # ty: ignore[unresolved-attribute]
|
||||
for i, embedding in enumerate(token_embeddings):
|
||||
# Format: show first few values, ..., then last few values
|
||||
if len(embedding) > 10:
|
||||
|
||||
@@ -207,8 +207,8 @@ def main():
|
||||
else:
|
||||
model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
|
||||
|
||||
encoded = tokenizer(prompt, return_tensors="pt")
|
||||
tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
|
||||
encoded = tokenizer(prompt, return_tensors="pt") # ty: ignore[call-non-callable]
|
||||
tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0]) # ty: ignore[unresolved-attribute]
|
||||
n_tokens = len(tokens)
|
||||
print(f"n_tokens: {n_tokens}");
|
||||
print(f"hidden_size: {model.config.hidden_size}")
|
||||
|
||||
@@ -60,24 +60,24 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (nrows == 1) {
|
||||
DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream);
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
offset_iterator, offset_iterator + 1, stream);
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream);
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
|
||||
stream);
|
||||
stream));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,22 +86,22 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (nrows == 1) {
|
||||
DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream);
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream);
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, stream);
|
||||
offset_iterator + 1, stream));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1173,7 +1173,11 @@ struct ggml_cuda_graph {
|
||||
std::vector<cudaGraphNode_t> nodes;
|
||||
bool disable_due_to_gpu_arch = false;
|
||||
bool warmup_complete = false;
|
||||
std::vector<ggml_tensor> nodes_copy;
|
||||
struct node_properties {
|
||||
ggml_tensor node;
|
||||
void * node_src_data_ptrs[GGML_MAX_SRC];
|
||||
};
|
||||
std::vector<node_properties> node_props;
|
||||
|
||||
bool is_enabled() const {
|
||||
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
||||
|
||||
@@ -2979,18 +2979,25 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
// Check if the graph size has changed
|
||||
if ((int)graph->nodes_copy.size() != cgraph->n_nodes) {
|
||||
if ((int)graph->node_props.size() != cgraph->n_nodes) {
|
||||
res = true;
|
||||
graph->nodes_copy.resize(cgraph->n_nodes);
|
||||
graph->node_props.resize(cgraph->n_nodes);
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (!res) {
|
||||
if (memcmp(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
|
||||
res = true;
|
||||
}
|
||||
ggml_cuda_graph::node_properties prop = {};
|
||||
memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor));
|
||||
|
||||
// if the backend scheduler is making copies of CPU tensors, the src pointers can be the same but with different data, see:
|
||||
// https://github.com/ggml-org/llama.cpp/pull/21472#discussion_r3052235188
|
||||
for (int j = 0; j < GGML_MAX_SRC; ++j) {
|
||||
prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j] ? cgraph->nodes[i]->src[j]->data : nullptr;
|
||||
}
|
||||
memcpy(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor));
|
||||
|
||||
if (!res && memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
|
||||
res = true;
|
||||
}
|
||||
graph->node_props[i] = prop;
|
||||
}
|
||||
|
||||
return res;
|
||||
|
||||
@@ -25,14 +25,14 @@ static void top_k_cub(ggml_cuda_pool & pool,
|
||||
auto indexes_in = cuda::make_counting_iterator(0);
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,
|
||||
env);
|
||||
CUDA_CHECK(DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,
|
||||
env));
|
||||
|
||||
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
|
||||
void * d_temp_storage = temp_storage_alloc.get();
|
||||
|
||||
DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
|
||||
ncols, k, env);
|
||||
CUDA_CHECK(DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
|
||||
ncols, k, env));
|
||||
}
|
||||
|
||||
#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
|
||||
|
||||
@@ -736,6 +736,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
|
||||
suffix = ne00 % 4 == 0 ? "_4" : "";
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q1_0:
|
||||
{
|
||||
nsg = N_SG_Q1_0;
|
||||
nr0 = N_R0_Q1_0;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
nsg = N_SG_Q4_0;
|
||||
@@ -948,6 +953,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m
|
||||
smem = 32*sizeof(float)*nr0;
|
||||
suffix = ne00 % 4 == 0 ? "_4" : "";
|
||||
} break;
|
||||
case GGML_TYPE_Q1_0:
|
||||
{
|
||||
nsg = N_SG_Q1_0;
|
||||
nr0 = N_R0_Q1_0;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
nsg = N_SG_Q4_0;
|
||||
|
||||
@@ -1184,6 +1184,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -1210,6 +1211,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
|
||||
@@ -8,6 +8,9 @@
|
||||
//
|
||||
// TODO: for optimal performance, become function of the device and work size
|
||||
|
||||
#define N_R0_Q1_0 8
|
||||
#define N_SG_Q1_0 2
|
||||
|
||||
#define N_R0_Q4_0 4
|
||||
#define N_SG_Q4_0 2
|
||||
|
||||
|
||||
@@ -2047,6 +2047,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
|
||||
op->src[0]->type == GGML_TYPE_F16 ||
|
||||
op->src[0]->type == GGML_TYPE_BF16 ||
|
||||
op->src[0]->type == GGML_TYPE_Q1_0 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_0 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_1 ||
|
||||
op->src[0]->type == GGML_TYPE_Q5_0 ||
|
||||
|
||||
@@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint8_t * qs = xb->qs;
|
||||
const float d = xb->d;
|
||||
const float neg_d = -d;
|
||||
|
||||
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
|
||||
const uint8_t b0 = qs[byte_offset];
|
||||
const uint8_t b1 = qs[byte_offset + 1];
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
|
||||
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
|
||||
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
|
||||
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
|
||||
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
|
||||
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
|
||||
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
|
||||
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
|
||||
|
||||
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
|
||||
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
|
||||
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
|
||||
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
|
||||
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
|
||||
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
|
||||
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
|
||||
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
|
||||
const float d = xb->d;
|
||||
const float neg_d = -d;
|
||||
const int base = il * 4;
|
||||
const uint8_t byte = xb->qs[base / 8];
|
||||
const int s = base % 8;
|
||||
|
||||
float4 reg_f;
|
||||
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
|
||||
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
|
||||
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
|
||||
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
|
||||
|
||||
reg = (type4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||
@@ -152,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
|
||||
float sum_abs = 0.0f;
|
||||
for (int j = 0; j < QK1_0; j++) {
|
||||
sum_abs += fabs(src[j]);
|
||||
}
|
||||
dst.d = sum_abs / QK1_0;
|
||||
|
||||
for (int j = 0; j < QK1_0 / 8; j++) {
|
||||
dst.qs[j] = 0;
|
||||
}
|
||||
for (int j = 0; j < QK1_0; j++) {
|
||||
if (src[j] >= 0.0f) {
|
||||
dst.qs[j / 8] |= (1 << (j % 8));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
@@ -3116,6 +3183,35 @@ kernel void kernel_group_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
|
||||
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
|
||||
device const uint8_t * qs = qb_curr->qs + il / 8;
|
||||
const uint8_t b0 = qs[0];
|
||||
const uint8_t b1 = qs[1];
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
acc += select(0.0f, yl[ 0], bool(b0 & 0x01));
|
||||
acc += select(0.0f, yl[ 1], bool(b0 & 0x02));
|
||||
acc += select(0.0f, yl[ 2], bool(b0 & 0x04));
|
||||
acc += select(0.0f, yl[ 3], bool(b0 & 0x08));
|
||||
acc += select(0.0f, yl[ 4], bool(b0 & 0x10));
|
||||
acc += select(0.0f, yl[ 5], bool(b0 & 0x20));
|
||||
acc += select(0.0f, yl[ 6], bool(b0 & 0x40));
|
||||
acc += select(0.0f, yl[ 7], bool(b0 & 0x80));
|
||||
|
||||
acc += select(0.0f, yl[ 8], bool(b1 & 0x01));
|
||||
acc += select(0.0f, yl[ 9], bool(b1 & 0x02));
|
||||
acc += select(0.0f, yl[10], bool(b1 & 0x04));
|
||||
acc += select(0.0f, yl[11], bool(b1 & 0x08));
|
||||
acc += select(0.0f, yl[12], bool(b1 & 0x10));
|
||||
acc += select(0.0f, yl[13], bool(b1 & 0x20));
|
||||
acc += select(0.0f, yl[14], bool(b1 & 0x40));
|
||||
acc += select(0.0f, yl[15], bool(b1 & 0x80));
|
||||
|
||||
return qb_curr->d * (2.0f * acc - sumy);
|
||||
}
|
||||
|
||||
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
||||
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
||||
// we assume that the yl's have been multiplied with the appropriate scale factor
|
||||
@@ -3337,6 +3433,85 @@ void mul_vec_q_n_f32_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_q1_0_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
const int nb = args.ne00/QK1_0;
|
||||
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * NSG + sgitg) * nr0;
|
||||
|
||||
const uint i12 = im%args.ne12;
|
||||
const uint i13 = im/args.ne12;
|
||||
|
||||
const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;
|
||||
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
device const block_q1_0 * ax[nr0];
|
||||
for (int row = 0; row < nr0; ++row) {
|
||||
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
|
||||
}
|
||||
|
||||
float yl[16];
|
||||
float sumf[nr0] = {0.f};
|
||||
|
||||
const short ix = (tiisg/8);
|
||||
const short il = (tiisg%8)*16;
|
||||
|
||||
device const float * yb = y + ix*QK1_0 + il;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
|
||||
float sumy = 0.f;
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
yl[i] = yb[i];
|
||||
sumy += yb[i];
|
||||
}
|
||||
|
||||
FOR_UNROLL (short row = 0; row < nr0; row++) {
|
||||
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
|
||||
}
|
||||
|
||||
yb += QK1_0 * (N_SIMDWIDTH/8);
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < nr0; ++row) {
|
||||
const float tot = simd_sum(sumf[row]);
|
||||
|
||||
if (tiisg == 0 && first_row + row < args.ne01) {
|
||||
dst_f32[first_row + row] = tot;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_q1_0_f32")]]
|
||||
kernel void kernel_mul_mv_q1_0_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mv_q4_0_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
@@ -3729,6 +3904,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4
|
||||
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
@@ -7133,6 +7313,7 @@ kernel void kernel_cpy_f32_q(
|
||||
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
|
||||
|
||||
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
|
||||
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
|
||||
@@ -7173,12 +7354,14 @@ kernel void kernel_cpy_q_f32(
|
||||
|
||||
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
|
||||
|
||||
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
|
||||
|
||||
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
|
||||
@@ -9776,6 +9959,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro
|
||||
|
||||
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
||||
|
||||
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
|
||||
@@ -9838,6 +10022,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
@@ -9861,6 +10046,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
@@ -9893,6 +10079,7 @@ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_m
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
@@ -9916,6 +10103,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
@@ -10070,6 +10258,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
|
||||
|
||||
@@ -589,6 +589,7 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
|
||||
ggml_backend_buffer_free(opt_ctx->buf_cpu);
|
||||
ggml_free(opt_ctx->ctx_static);
|
||||
ggml_free(opt_ctx->ctx_cpu);
|
||||
ggml_free(opt_ctx->ctx_copy);
|
||||
delete opt_ctx;
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,10 @@ void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst);
|
||||
} break;
|
||||
case 512: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_sycl_flash_attn_ext_tile_case<512, 512>(ctx, dst);
|
||||
} break;
|
||||
case 576: {
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst);
|
||||
|
||||
@@ -67,6 +67,12 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
@@ -124,6 +130,12 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
@@ -131,134 +143,6 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co
|
||||
return 0;
|
||||
}
|
||||
|
||||
static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
|
||||
if(fast_fp16_available(cc))
|
||||
return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols);
|
||||
@@ -1293,6 +1177,16 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_sycl_context & ctx, ggm
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
// ncols2=2 and ncols2=1 fallbacks only for cases where ncols=2 config exists (DKQ == DV).
|
||||
// For DKQ == 576, DV == 512 only GQA-optimized variants are implemented.
|
||||
if constexpr (DKQ == DV) {
|
||||
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
@@ -1347,5 +1241,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
|
||||
extern DECL_FATTN_TILE_CASE(112, 112);
|
||||
extern DECL_FATTN_TILE_CASE(128, 128);
|
||||
extern DECL_FATTN_TILE_CASE(256, 256);
|
||||
extern DECL_FATTN_TILE_CASE(512, 512);
|
||||
extern DECL_FATTN_TILE_CASE(576, 512);
|
||||
|
||||
|
||||
@@ -664,4 +664,11 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q8_0)
|
||||
|
||||
#endif // GGML_SYCL_FATTN_VEC_HPP
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
FATTN_VEC_CASE( 64, type_K, type_V) \
|
||||
FATTN_VEC_CASE(128, type_K, type_V) \
|
||||
FATTN_VEC_CASE(256, type_K, type_V) \
|
||||
FATTN_VEC_CASE(512, type_K, type_V) \
|
||||
|
||||
static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * Q = dst->src[0];
|
||||
@@ -141,6 +142,7 @@ static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const
|
||||
case 128:
|
||||
case 112:
|
||||
case 256:
|
||||
case 512:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
@@ -185,7 +187,7 @@ static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const
|
||||
}
|
||||
|
||||
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 512 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
// Todo: Use the XMX kernel if possible:
|
||||
|
||||
|
||||
@@ -411,11 +411,22 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
|
||||
assert(tensor->view_src->buffer->buft == buffer->buft);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
|
||||
!g_ggml_sycl_disable_optimize) {
|
||||
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
|
||||
tensor->extra = extra;
|
||||
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
|
||||
|
||||
if (!g_ggml_sycl_disable_optimize) {
|
||||
// set reorder extra buffer based on supported type
|
||||
switch (tensor->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q6_K:{
|
||||
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
|
||||
tensor->extra = extra;
|
||||
ctx->tensor_extras.push_back(extra);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (ggml_is_quantized(tensor->type)) {
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.hpp"
|
||||
|
||||
DECL_FATTN_TILE_CASE(512, 512);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#if defined(A_TYPE_VEC4)
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
|
||||
#if defined(A_TYPEV4)
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPEV4 data_a_v4[];};
|
||||
#endif
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
|
||||
@@ -17,11 +17,11 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32
|
||||
#endif
|
||||
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
#ifdef B_TYPE_VEC2
|
||||
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
|
||||
#ifdef B_TYPEV2
|
||||
layout (binding = 1) readonly buffer BV2 {B_TYPEV2 data_b_v2[];};
|
||||
#endif
|
||||
#ifdef B_TYPE_VEC4
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
#ifdef B_TYPEV4
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPEV4 data_b_v4[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
@@ -41,7 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
|
||||
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
|
||||
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
|
||||
|
||||
const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);
|
||||
const FLOAT_TYPEV2 dm = vec2(data_a[ib0 + i].dm);
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
|
||||
|
||||
@@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
|
||||
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
|
||||
const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm);
|
||||
|
||||
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
|
||||
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
|
||||
|
||||
@@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
|
||||
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
|
||||
const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm);
|
||||
|
||||
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
|
||||
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
|
||||
|
||||
@@ -11,8 +11,8 @@ FLOAT_TYPE get_dm(uint ib) {
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
||||
FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
FLOAT_TYPEV2 get_dm(uint ib) {
|
||||
return FLOAT_TYPEV2(data_a_packed32[ib].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -23,9 +23,9 @@ FLOAT_TYPE get_dm(uint ib) {
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q2_K)
|
||||
FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
||||
FLOAT_TYPEV2 get_dm(uint ib) {
|
||||
const uint ib_k = ib / 8;
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
|
||||
return FLOAT_TYPEV2(data_a_packed32[ib_k].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -304,7 +304,7 @@ vec2 get_dm_scale(uint ib, uint iqs) {
|
||||
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
|
||||
}
|
||||
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
|
||||
return FLOAT_TYPEV2(data_a_packed32[ib_k].dm) * FLOAT_TYPEV2(scale_dm);
|
||||
}
|
||||
|
||||
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
|
||||
@@ -422,7 +422,7 @@ vec2 get_dm(uint ib, uint iqs) {
|
||||
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
|
||||
|
||||
// the -1 cancels out the bias in iq1s_grid_gpu
|
||||
return FLOAT_TYPE_VEC2(dl, dl * (delta - 1));
|
||||
return FLOAT_TYPEV2(dl, dl * (delta - 1));
|
||||
}
|
||||
|
||||
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
|
||||
|
||||
@@ -125,8 +125,8 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit
|
||||
#define SHMEM_STRIDE (BK / 2 + 1)
|
||||
#endif
|
||||
|
||||
shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPEV2 buf_a[BM * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPEV2 buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
||||
@@ -258,17 +258,17 @@ void main() {
|
||||
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
|
||||
}
|
||||
#else
|
||||
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
|
||||
ACC_TYPEV2 sums[WMITER * TM * WNITER * TN/2];
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
FLOAT_TYPE_VEC4 cache_a[WMITER * TM];
|
||||
FLOAT_TYPE_VEC4 cache_b;
|
||||
FLOAT_TYPEV4 cache_a[WMITER * TM];
|
||||
FLOAT_TYPEV4 cache_b;
|
||||
#else
|
||||
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
|
||||
FLOAT_TYPE_VEC2 cache_b;
|
||||
FLOAT_TYPEV2 cache_a[WMITER * TM];
|
||||
FLOAT_TYPEV2 cache_b;
|
||||
#endif
|
||||
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
|
||||
sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
|
||||
sums[i] = ACC_TYPEV2(0.0f, 0.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
#if LOAD_VEC_A == 8
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]);
|
||||
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa[0].xy;
|
||||
buf_a[buf_idx + 1] = aa[0].zw;
|
||||
buf_a[buf_idx + 2] = aa[1].xy;
|
||||
@@ -11,38 +11,38 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
#elif LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
#else // LOAD_VEC_BATCH_A == 2
|
||||
const uint idx = pos_a + col * p.stride_a + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
|
||||
data_a[idx + 1]);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx],
|
||||
data_a[idx + 1]);
|
||||
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f);
|
||||
} else {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_BF16)
|
||||
#if LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
#else // LOAD_VEC_BATCH_A == 2
|
||||
const uint idx = pos_a + col * p.stride_a + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
|
||||
TO_FLOAT_TYPE(data_a[idx + 1]));
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]),
|
||||
TO_FLOAT_TYPE(data_a[idx + 1]));
|
||||
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
|
||||
} else {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_Q4_0)
|
||||
@@ -57,10 +57,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
|
||||
const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy);
|
||||
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v0.zw);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPEV2(v1.xy);
|
||||
buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.zw);
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
@@ -73,10 +73,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y;
|
||||
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
|
||||
buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy);
|
||||
buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xy);
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPEV2(v0.zw);
|
||||
buf_a[buf_idx + 8 ] = FLOAT_TYPEV2(v1.xy);
|
||||
buf_a[buf_idx + 9 ] = FLOAT_TYPEV2(v1.zw);
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
@@ -92,8 +92,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
|
||||
const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xz);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPEV2(v.yw);
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
@@ -112,10 +112,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y;
|
||||
const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw);
|
||||
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xz);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v1.xz);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPEV2(v0.yw);
|
||||
buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.yw);
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -128,8 +128,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
|
||||
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw);
|
||||
#elif defined(DATA_A_Q2_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -147,8 +147,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
|
||||
const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw);
|
||||
#elif defined(DATA_A_Q3_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -171,8 +171,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec2 qs = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> qsshift) & 0x0303).xy);
|
||||
const vec2 hm = vec2(unpack8(((uint(data_a_packed16[ib].hmask[hmi / 2]) >> (4 * n + halfsplit)) & 0x0101 ^ 0x0101) << 2).xy);
|
||||
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * (qs.x - hm.x),
|
||||
dl * (qs.y - hm.y));
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(dl * (qs.x - hm.x),
|
||||
dl * (qs.y - hm.y));
|
||||
#elif defined(DATA_A_Q4_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -206,8 +206,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
|
||||
const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m));
|
||||
#elif defined(DATA_A_Q5_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -244,8 +244,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4;
|
||||
const vec4 q = vec4(unpack8(qs | qh));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m));
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -267,7 +267,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint qh = (uint(data_a_packed16[ib].qh[qhi]) >> qhshift) & 0x0303;
|
||||
const vec2 q = (vec2(unpack8(ql | (qh << 4)).xy) - 32) * dscale;
|
||||
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(q.x, q.y);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(q.x, q.y);
|
||||
#elif defined(DATA_A_IQ1_S)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -284,8 +284,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
|
||||
|
||||
[[unroll]] for (int k = 0; k < 4; ++k) {
|
||||
buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
|
||||
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
|
||||
buf_a[buf_idx + k] = FLOAT_TYPEV2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
|
||||
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
|
||||
}
|
||||
#elif defined(DATA_A_IQ1_M)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
@@ -306,8 +306,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
|
||||
[[unroll]] for (int k = 0; k < 4; ++k) {
|
||||
buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
|
||||
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
|
||||
buf_a[buf_idx + k] = FLOAT_TYPEV2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
|
||||
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
|
||||
}
|
||||
#elif defined(DATA_A_IQ2_XXS)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
@@ -332,14 +332,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ2_XS)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -358,14 +358,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ2_S)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -386,14 +386,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ3_XXS)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -414,10 +414,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint grid = iq3xxs_grid[qs];
|
||||
const vec4 v = db * vec4(unpack8(grid));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
|
||||
(sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
|
||||
(sign & 8) != 0 ? -v.w : v.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2((sign & 1) != 0 ? -v.x : v.x,
|
||||
(sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z,
|
||||
(sign & 8) != 0 ? -v.w : v.w);
|
||||
#elif defined(DATA_A_IQ3_S)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -436,10 +436,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
|
||||
const vec4 v = db * vec4(unpack8(grid));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
|
||||
(sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
|
||||
(sign & 8) != 0 ? -v.w : v.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2((sign & 1) != 0 ? -v.x : v.x,
|
||||
(sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z,
|
||||
(sign & 8) != 0 ? -v.w : v.w);
|
||||
#elif defined(DATA_A_IQ4_XS)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -456,8 +456,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const float d = float(data_a[ib].d);
|
||||
const vec4 v = d * float(int(sl | (sh << 4)) - 32) * vec4(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw);
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
@@ -468,10 +468,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
|
||||
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
|
||||
|
||||
buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF],
|
||||
kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]);
|
||||
buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)],
|
||||
kvalues_iq4nl[vui >> 12]);
|
||||
buf_a[buf_idx ] = d * FLOAT_TYPEV2(kvalues_iq4nl[vui & 0xF],
|
||||
kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]);
|
||||
buf_a[buf_idx + 8] = d * FLOAT_TYPEV2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)],
|
||||
kvalues_iq4nl[vui >> 12]);
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
@@ -483,10 +483,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint vui = uint(data_a[ib].qs[iqs]);
|
||||
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d,
|
||||
kvalues_mxfp4[vui2 & 0xF] * d);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d,
|
||||
kvalues_mxfp4[vui2 >> 4] * d);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d,
|
||||
kvalues_mxfp4[vui2 & 0xF] * d);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d,
|
||||
kvalues_mxfp4[vui2 >> 4] * d);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -496,7 +496,7 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
@@ -505,9 +505,9 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
#if defined(DATA_B_BF16)
|
||||
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
#else
|
||||
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
@@ -515,12 +515,12 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
||||
const uint idx = pos_b + col * p.stride_b + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
} else if (idx_n < p.N && block + row * 2 < end_k) {
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
} else {
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -531,7 +531,7 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
@@ -541,9 +541,9 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
#if defined(DATA_B_BF16)
|
||||
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
#else
|
||||
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
@@ -553,14 +553,14 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
||||
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
} else if (row_i < _ne1 && block + row * 2 < end_k) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
} else {
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -72,7 +72,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm);
|
||||
buf_a[buf_ib].qh = data_a_packed32[ib].qh;
|
||||
}
|
||||
#endif
|
||||
@@ -203,7 +203,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
|
||||
buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib_k].dm);
|
||||
buf_a[buf_ib].scales = unpack8(uint32_t(data_a_packed16[ib_k].scales[iqs_k / 8])).xy; // vec4 used due to #12147
|
||||
}
|
||||
}
|
||||
@@ -264,7 +264,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
|
||||
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales - 32));
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPEV2(float(data_a_packed16[ib_k].d) * vec2(scales - 32));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -334,7 +334,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
|
||||
}
|
||||
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm));
|
||||
buf_a[buf_ib].dm = FLOAT_TYPEV2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,7 +385,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint is = iqs_k / 4;
|
||||
const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy;
|
||||
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales));
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPEV2(float(data_a_packed16[ib_k].d) * vec2(scales));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -426,7 +426,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bo
|
||||
const uint ib_inner = ib % 4;
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
|
||||
buf_b[buf_ib].ds = FLOAT_TYPEV2(data_b[ib_outer].ds[ib_inner]);
|
||||
}
|
||||
|
||||
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
|
||||
@@ -436,7 +436,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bo
|
||||
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
|
||||
} else {
|
||||
if (iqs == 0) {
|
||||
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
|
||||
buf_b[buf_ib].ds = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
|
||||
buf_b[buf_ib].qs[iqs * 4 ] = 0;
|
||||
|
||||
@@ -8,7 +8,7 @@ struct block_a_cache {
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[16/4];
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
FLOAT_TYPEV2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
#define QUANT_R_MMQ 2
|
||||
@@ -22,7 +22,7 @@ struct block_a_cache {
|
||||
struct block_a_cache {
|
||||
uint32_t qs[16/4];
|
||||
uint32_t qh;
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
FLOAT_TYPEV2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
#define QUANT_R_MMQ 1
|
||||
@@ -43,36 +43,36 @@ struct block_a_cache {
|
||||
struct block_a_cache {
|
||||
uint32_t qs[2];
|
||||
u8vec2 scales;
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
FLOAT_TYPEV2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q3_K)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[4];
|
||||
FLOAT_TYPE_VEC2 d_scales;
|
||||
FLOAT_TYPEV2 d_scales;
|
||||
};
|
||||
#elif defined(DATA_A_Q4_K)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[4];
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
FLOAT_TYPEV2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q5_K)
|
||||
#define QUANT_R_MMQ 1
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
FLOAT_TYPEV2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
#define QUANT_R_MMQ 1
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE_VEC2 d_scales;
|
||||
FLOAT_TYPEV2 d_scales;
|
||||
};
|
||||
#endif
|
||||
|
||||
struct block_b_cache
|
||||
{
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE_VEC2 ds;
|
||||
FLOAT_TYPEV2 ds;
|
||||
};
|
||||
|
||||
@@ -446,8 +446,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
base_dict["FLOAT16"] = "1";
|
||||
}
|
||||
|
||||
base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
|
||||
base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2";
|
||||
base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
|
||||
base_dict["ACC_TYPEV2"] = f16acc ? "f16vec2" : "vec2";
|
||||
if (f16acc) {
|
||||
base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
|
||||
}
|
||||
@@ -514,10 +514,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
};
|
||||
|
||||
const std::map<std::string, std::string> float_type_dict_f16 = {
|
||||
{"FLOAT_TYPE", FLOAT_TYPE(1, "f16")},
|
||||
{"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")},
|
||||
{"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")},
|
||||
{"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")},
|
||||
{"FLOAT_TYPE", FLOAT_TYPE(1, "f16")},
|
||||
{"FLOAT_TYPEV2", FLOAT_TYPE(2, "f16")},
|
||||
{"FLOAT_TYPEV4", FLOAT_TYPE(4, "f16")},
|
||||
{"FLOAT_TYPEV8", FLOAT_TYPE(8, "f16")},
|
||||
};
|
||||
|
||||
// Shaders with f16 B_TYPE
|
||||
@@ -536,9 +536,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
|
||||
|
||||
const std::map<std::string, std::string> float_type_dict_bf16 = {
|
||||
{"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")},
|
||||
{"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")},
|
||||
{"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")},
|
||||
{"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")},
|
||||
{"FLOAT_TYPEV2", FLOAT_TYPE(2, "bf16")},
|
||||
{"FLOAT_TYPEV4", FLOAT_TYPE(4, "bf16")},
|
||||
};
|
||||
|
||||
// If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
|
||||
@@ -569,10 +569,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
|
||||
|
||||
const std::map<std::string, std::string> float_type_dict = {
|
||||
{"FLOAT_TYPE", FLOAT_TYPE(1, tname)},
|
||||
{"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)},
|
||||
{"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)},
|
||||
{"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)},
|
||||
{"FLOAT_TYPE", FLOAT_TYPE(1, tname)},
|
||||
{"FLOAT_TYPEV2", FLOAT_TYPE(2, tname)},
|
||||
{"FLOAT_TYPEV4", FLOAT_TYPE(4, tname)},
|
||||
{"FLOAT_TYPEV8", FLOAT_TYPE(8, tname)},
|
||||
};
|
||||
|
||||
// don't generate f32 variants for coopmat2
|
||||
@@ -676,36 +676,36 @@ void process_shaders() {
|
||||
}
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}};
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
// mul mat vec
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
|
||||
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
|
||||
// mul mat vec with integer dot product
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") {
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -726,9 +726,9 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("get_rows_i32", "get_rows.comp", {{"TEMP_TYPE", "uint"}, {"A_TYPE", "uint"}, {"B_TYPE", "int"}, {"D_TYPE", "uint"}});
|
||||
|
||||
string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
|
||||
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
|
||||
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}});
|
||||
|
||||
// Norms
|
||||
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
@@ -4033,8 +4033,14 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
|
||||
|
||||
static ggml_backend_webgpu_reg_context ctx;
|
||||
static ggml_backend_reg reg = {
|
||||
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
||||
/* .iface = */ ggml_backend_webgpu_reg_i,
|
||||
/* .context = */ &ctx,
|
||||
};
|
||||
|
||||
ctx.name = GGML_WEBGPU_NAME;
|
||||
ctx.device_count = 1;
|
||||
ctx.device_count = 0;
|
||||
|
||||
wgpu::InstanceDescriptor instance_descriptor{};
|
||||
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
|
||||
@@ -4053,19 +4059,28 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
|
||||
ctx.webgpu_global_ctx->instance = std::move(inst);
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
if (ctx.webgpu_global_ctx->instance == nullptr) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr);
|
||||
wgpu::Adapter adapter;
|
||||
if (ctx.webgpu_global_ctx->instance != nullptr) {
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
// probe for adapter support
|
||||
ctx.webgpu_global_ctx->instance.WaitAny(
|
||||
ctx.webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
adapter = std::move(_adapter);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
}
|
||||
|
||||
if (adapter != nullptr) {
|
||||
ctx.device_count = 1;
|
||||
}
|
||||
|
||||
static ggml_backend_reg reg = {
|
||||
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
||||
/* .iface = */ ggml_backend_webgpu_reg_i,
|
||||
/* .context = */ &ctx,
|
||||
};
|
||||
return ®
|
||||
}
|
||||
|
||||
|
||||
@@ -4122,6 +4122,7 @@ class VisionProjectorType:
|
||||
LIGHTONOCR = "lightonocr"
|
||||
COGVLM = "cogvlm"
|
||||
JANUS_PRO = "janus_pro"
|
||||
DOTSOCR = "dots_ocr"
|
||||
DEEPSEEKOCR = "deepseekocr"
|
||||
LFM2A = "lfm2a" # audio
|
||||
MUSIC_FLAMINGO = "musicflamingo" # audio
|
||||
|
||||
@@ -1359,6 +1359,7 @@ class TensorNameMap:
|
||||
"visual.merger.mlp.{bid}", # qwen2vl
|
||||
"mlp_AR.linear_{bid}", # PaddleOCR-VL
|
||||
"merger.mlp.{bid}",
|
||||
"vision_tower.merger.mlp.{bid}", # dots.ocr
|
||||
"vit.perceive.proj.{bid}", # HunyuanOCR (proj.0 = conv1, proj.2 = conv2)
|
||||
),
|
||||
|
||||
@@ -1406,11 +1407,13 @@ class TensorNameMap:
|
||||
"siglip2.vision_model.embeddings.patch_embedding",
|
||||
"vision_model.radio_model.model.patch_generator.embedder", # Nemotron Nano v2 VL
|
||||
"model.vision_tower.patch_embedder.input_proj", # gemma4
|
||||
"vision_tower.patch_embed.patchifier.proj", # dots.ocr
|
||||
"vision_model.conv1", # Step3-VL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_NORM: (
|
||||
"visual.post_conv_layernorm", # glm4v
|
||||
"vision_tower.patch_embed.patchifier.norm", # dots.ocr
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS: (
|
||||
@@ -1441,6 +1444,7 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_QKV: (
|
||||
"visual.blocks.{bid}.attn.qkv", # qwen3vl
|
||||
"vision_tower.blocks.{bid}.attn.qkv", # dots.ocr
|
||||
"model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm
|
||||
"model.vision_model.transformer.layers.{bid}.self_attn.qkv_proj", # Deepseek-OCR CLIP
|
||||
"vision_tower.encoder.blocks.{bid}.wqkv", # Kimi-K2.5
|
||||
@@ -1526,6 +1530,7 @@ class TensorNameMap:
|
||||
"model.vision_model.transformer.layers.{bid}.layer_norm1", # Deepseek-OCR CLIP
|
||||
"siglip2.vision_model.encoder.layers.{bid}.layer_norm1",
|
||||
"vision_model.radio_model.model.blocks.{bid}.norm1", # Nemotron Nano v2 VL
|
||||
"vision_tower.blocks.{bid}.norm1", # dots.ocr
|
||||
"vision_model.transformer.resblocks.{bid}.ln_1", # Step3-VL
|
||||
),
|
||||
|
||||
@@ -1547,6 +1552,7 @@ class TensorNameMap:
|
||||
"siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl
|
||||
"vision_model.radio_model.model.blocks.{bid}.attn.proj", # Nemotron Nano v2 VL
|
||||
"vision_model.model.layers.{bid}.self_attn.o_proj.linear", # gemma4
|
||||
"vision_tower.blocks.{bid}.attn.proj", # dots.ocr
|
||||
"vision_model.transformer.resblocks.{bid}.attn.out_proj", # Step3-VL
|
||||
),
|
||||
|
||||
@@ -1567,6 +1573,7 @@ class TensorNameMap:
|
||||
"siglip2.vision_model.encoder.layers.{bid}.layer_norm2",
|
||||
"vision_model.radio_model.model.blocks.{bid}.norm2", # Nemotron Nano v2 VL
|
||||
"vision_model.model.layers.{bid}.pre_feedforward_layernorm", # gemma4
|
||||
"vision_tower.blocks.{bid}.norm2", # dots.ocr
|
||||
"vision_model.transformer.resblocks.{bid}.ln_2", # Step3-VL
|
||||
),
|
||||
|
||||
@@ -1649,6 +1656,7 @@ class TensorNameMap:
|
||||
"vision_encoder.ln_pre", # pixtral
|
||||
"vision_model.layernorm_pre", # llama4
|
||||
"model.vision_model.pre_layrnorm", # Deepseek-OCR CLIP
|
||||
"vision_tower.patch_embed.patchifier.norm", # dots.ocr
|
||||
"vision_model.ln_pre", # Step3-VL
|
||||
),
|
||||
|
||||
@@ -1664,6 +1672,7 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.V_MM_POST_NORM: (
|
||||
"visual.merger.post_projection_norm", # glm4v
|
||||
"vision_tower.post_trunk_norm", # dots.ocr
|
||||
"vit.perceive.after_rms", # HunyuanOCR
|
||||
),
|
||||
|
||||
@@ -1680,6 +1689,7 @@ class TensorNameMap:
|
||||
"model.vision.linear_proj.norm1", # cogvlm
|
||||
"mlp_AR.pre_norm", # PaddleOCR-VL
|
||||
"merger.ln_q",
|
||||
"vision_tower.merger.ln_q", # dots.ocr
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
|
||||
|
||||
@@ -543,7 +543,7 @@ class LlamaHfVocab(Vocab):
|
||||
cache_dir=base_path,
|
||||
local_files_only=True,
|
||||
)
|
||||
assert self.tokenizer.is_fast # assume tokenizer.json is used
|
||||
assert self.tokenizer.is_fast # assume tokenizer.json is used # ty: ignore[unresolved-attribute]
|
||||
|
||||
# Initialize lists and dictionaries for added tokens
|
||||
self.added_tokens_list = []
|
||||
@@ -552,30 +552,30 @@ class LlamaHfVocab(Vocab):
|
||||
|
||||
# Process added tokens
|
||||
for tok, tokidx in sorted(
|
||||
self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
|
||||
self.tokenizer.get_added_vocab().items(), key=lambda x: x[1] # ty: ignore[unresolved-attribute]
|
||||
):
|
||||
# Only consider added tokens that are not in the base vocabulary
|
||||
if tokidx >= self.tokenizer.vocab_size:
|
||||
if tokidx >= self.tokenizer.vocab_size: # ty: ignore[unresolved-attribute]
|
||||
self.added_tokens_list.append(tok)
|
||||
self.added_tokens_dict[tok] = tokidx
|
||||
self.added_tokens_ids.add(tokidx)
|
||||
|
||||
# Store special tokens and their IDs
|
||||
self.specials = {
|
||||
tok: self.tokenizer.get_vocab()[tok]
|
||||
for tok in self.tokenizer.all_special_tokens
|
||||
tok: self.tokenizer.get_vocab()[tok] # ty: ignore[unresolved-attribute]
|
||||
for tok in self.tokenizer.all_special_tokens # ty: ignore[unresolved-attribute]
|
||||
}
|
||||
self.special_ids = set(self.tokenizer.all_special_ids)
|
||||
self.special_ids = set(self.tokenizer.all_special_ids) # ty: ignore[unresolved-attribute]
|
||||
|
||||
# Set vocabulary sizes
|
||||
self.vocab_size_base = self.tokenizer.vocab_size
|
||||
self.vocab_size_base = self.tokenizer.vocab_size # ty: ignore[unresolved-attribute]
|
||||
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
|
||||
|
||||
self.fname_tokenizer = fname_tokenizer
|
||||
|
||||
def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
reverse_vocab = {
|
||||
id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
|
||||
id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items() # ty: ignore[unresolved-attribute]
|
||||
}
|
||||
|
||||
for token_id in range(self.vocab_size_base):
|
||||
@@ -616,7 +616,7 @@ class LlamaHfVocab(Vocab):
|
||||
yield text.encode("utf-8"), score, toktype
|
||||
|
||||
def has_newline_token(self):
|
||||
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
|
||||
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab # ty: ignore[unresolved-attribute]
|
||||
|
||||
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
yield from self.hf_tokens()
|
||||
|
||||
BIN
models/ggml-vocab-gemma-4.gguf
Normal file
BIN
models/ggml-vocab-gemma-4.gguf
Normal file
Binary file not shown.
111
models/ggml-vocab-gemma-4.gguf.inp
Normal file
111
models/ggml-vocab-gemma-4.gguf.inp
Normal file
@@ -0,0 +1,111 @@
|
||||
ied 4 ½ months
|
||||
__ggml_vocab_test__
|
||||
Äpfel
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
Hello world
|
||||
__ggml_vocab_test__
|
||||
Hello world
|
||||
__ggml_vocab_test__
|
||||
Hello World
|
||||
__ggml_vocab_test__
|
||||
Hello World
|
||||
__ggml_vocab_test__
|
||||
Hello World!
|
||||
__ggml_vocab_test__
|
||||
Hello, world!
|
||||
__ggml_vocab_test__
|
||||
Hello, world!
|
||||
__ggml_vocab_test__
|
||||
this is 🦙.cpp
|
||||
__ggml_vocab_test__
|
||||
w048 7tuijk dsdfhu
|
||||
__ggml_vocab_test__
|
||||
нещо на Български
|
||||
__ggml_vocab_test__
|
||||
កាន់តែពិសេសអាចខលចេញ
|
||||
__ggml_vocab_test__
|
||||
🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
(
|
||||
__ggml_vocab_test__
|
||||
|
||||
=
|
||||
__ggml_vocab_test__
|
||||
' era
|
||||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
__ggml_vocab_test__
|
||||
333
|
||||
__ggml_vocab_test__
|
||||
3333
|
||||
__ggml_vocab_test__
|
||||
33333
|
||||
__ggml_vocab_test__
|
||||
333333
|
||||
__ggml_vocab_test__
|
||||
3333333
|
||||
__ggml_vocab_test__
|
||||
33333333
|
||||
__ggml_vocab_test__
|
||||
333333333
|
||||
__ggml_vocab_test__
|
||||
Cửa Việt
|
||||
__ggml_vocab_test__
|
||||
discards
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
|
||||
46
models/ggml-vocab-gemma-4.gguf.out
Normal file
46
models/ggml-vocab-gemma-4.gguf.out
Normal file
@@ -0,0 +1,46 @@
|
||||
1178 236743 236812 47041 3794
|
||||
239122 22744 535
|
||||
|
||||
236743
|
||||
138
|
||||
139
|
||||
255968
|
||||
107
|
||||
108
|
||||
109
|
||||
255968 107
|
||||
9259 1902
|
||||
26352 1902
|
||||
9259 4109
|
||||
26352 4109
|
||||
26352 4109 236888
|
||||
9259 236764 1902 236888
|
||||
26352 236764 1902 236888
|
||||
672 563 236743 478 397 404 391 236761 12362
|
||||
236765 236771 236812 236828 236743 236832 11372 12065 31806 3405 9360
|
||||
1337 12515 1333 4632 165543 3830
|
||||
234889 63031 219876 66212 239077 237907 144494
|
||||
242015 568 7382 236768 236743 247717 237243 248989 238178 568 43819 111730 150567 236768 113452 568 8960 64334 600 815 1061 1852 8369 236768
|
||||
9259
|
||||
26352
|
||||
138 9259
|
||||
139 9259
|
||||
140 9259
|
||||
140 9259 107 140 9259
|
||||
568
|
||||
107 578
|
||||
236789 6933
|
||||
9259 236764 570 236789 712 236888 2088 659 611 170124 2360 62133 237075 17641 11700 236770 236800 236770 236812 236770 236810 236770 237471 238352
|
||||
123947
|
||||
236800
|
||||
236800 236800
|
||||
236800 236800 236800
|
||||
236800 236800 236800 236800
|
||||
236800 236800 236800 236800 236800
|
||||
236800 236800 236800 236800 236800 236800
|
||||
236800 236800 236800 236800 236800 236800 236800
|
||||
236800 236800 236800 236800 236800 236800 236800 236800
|
||||
236800 236800 236800 236800 236800 236800 236800 236800 236800
|
||||
236780 29719 33154
|
||||
2243 2206
|
||||
107 236743 108 236743 109 236743 255968 236743 255969 236743 255968 107 138 107 139 107 140 107 141 107 242015 568 7382 236768 236743 247717 237243 248989 238178 568 43819 111730 150567 236768 113452 236743 478 397 404 391 478 397 404 391 236743 236800 236743 236800 236800 236743 236800 236800 236800 236743 236800 236800 236800 236800 236743 236800 236800 236800 236800 236800 236743 236800 236800 236800 236800 236800 236800 236743 236800 236800 236800 236800 236800 236800 236800 236743 236800 236800 236800 236800 236800 236800 236800 236800 236743 236800 236761 236800 236743 236800 856 236800 236743 236800 1390 236800 90986 92814 63031 219876 66212 241702 2360 62133 237075 17641 11700 236770 236800 236770 236812 236770 236810 236770 237471 238352 80448 120697 210119 1333 4632 165543 3830 9451 159561 2629 2629 2717 84491 19938 123947 38950 10371 564 236789 560 1010 756 151812 668 236789 236751 993 236764 756 1357 611 2889 236881 756 236792 711 2889 564 236789 859 1386 625 236764 756 236796 611 1133 1070 11115 236881 1191 236789 32541 496 236789 95635
|
||||
@@ -18,7 +18,7 @@ classifiers = [
|
||||
python = ">=3.9"
|
||||
numpy = "^1.25.0"
|
||||
sentencepiece = ">=0.1.98,<0.3.0"
|
||||
transformers = ">=4.35.2,<5.0.0"
|
||||
transformers = "==5.5.1"
|
||||
protobuf = ">=4.21.0,<5.0.0"
|
||||
gguf = { path = "./gguf-py" }
|
||||
torch = { version = "^2.2.0", source = "pytorch" }
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
numpy~=1.26.4
|
||||
sentencepiece>=0.1.98,<0.3.0
|
||||
|
||||
transformers>=4.57.1,<5.0.0
|
||||
transformers==5.5.1
|
||||
|
||||
gguf>=0.1.0
|
||||
protobuf>=4.21.0,<5.0.0
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
aiohttp~=3.9.3
|
||||
pytest~=8.3.3
|
||||
huggingface_hub>=0.34.0,<1.0
|
||||
huggingface_hub>=1.5.0,<2.0
|
||||
matplotlib~=3.10.0
|
||||
numpy~=1.26.4
|
||||
openai~=2.14.0
|
||||
|
||||
@@ -558,20 +558,20 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
|
||||
// example: https://github.com/ggml-org/llama.cpp/pull/17548
|
||||
//
|
||||
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // do the norms on the first layer (not the input layer)
|
||||
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
|
||||
{LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
|
||||
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
|
||||
{LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
|
||||
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
|
||||
{LLM_TENSOR_ROPE_FACTORS_LONG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
|
||||
{LLM_TENSOR_ROPE_FACTORS_SHORT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
|
||||
@@ -708,9 +708,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
// altup / laurel (gemma 3n)
|
||||
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
|
||||
@@ -2942,7 +2942,7 @@ llama_context * llama_init_from_model(
|
||||
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
||||
}
|
||||
|
||||
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
|
||||
if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) {
|
||||
const uint32_t blck_size = ggml_blck_size(params.type_k);
|
||||
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
||||
if (model->hparams.n_embd_head_k(il) % blck_size != 0) {
|
||||
@@ -2953,7 +2953,7 @@ llama_context * llama_init_from_model(
|
||||
}
|
||||
}
|
||||
|
||||
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
|
||||
if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_v)) {
|
||||
const uint32_t blck_size = ggml_blck_size(params.type_v);
|
||||
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
||||
if (model->hparams.n_embd_head_v(il) % blck_size != 0) {
|
||||
|
||||
@@ -4211,13 +4211,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
|
||||
altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
|
||||
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
|
||||
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0);
|
||||
altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
|
||||
altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
|
||||
|
||||
per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
|
||||
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0);
|
||||
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_altup}, 0);
|
||||
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
|
||||
@@ -4276,9 +4277,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
if (n_embd_per_layer > 0) {
|
||||
tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0);
|
||||
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_per_layer * n_layer}, 0);
|
||||
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_per_layer}, 0);
|
||||
per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0);
|
||||
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_per_layer * n_layer}, 0);
|
||||
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_per_layer}, 0);
|
||||
}
|
||||
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
|
||||
@@ -534,9 +534,9 @@ struct llama_model {
|
||||
struct ggml_tensor * conv1d_b = nullptr;
|
||||
|
||||
// gemma3n altup
|
||||
struct ggml_tensor * tok_embd_per_layer = nullptr;
|
||||
struct ggml_tensor * altup_proj = nullptr;
|
||||
struct ggml_tensor * altup_unembd_proj = nullptr;
|
||||
struct ggml_tensor * per_layer_tok_embd = nullptr;
|
||||
struct ggml_tensor * per_layer_model_proj = nullptr;
|
||||
struct ggml_tensor * per_layer_proj_norm = nullptr;
|
||||
|
||||
|
||||
@@ -659,8 +659,17 @@ struct llm_tokenizer_bpe_session {
|
||||
|
||||
if (token == LLAMA_TOKEN_NULL) {
|
||||
for (auto j = str.begin(); j != str.end(); ++j) {
|
||||
std::string byte_str(1, *j);
|
||||
auto token_multibyte = vocab.text_to_token(byte_str);
|
||||
llama_token token_multibyte = LLAMA_TOKEN_NULL;
|
||||
if (tokenizer.byte_encode) {
|
||||
std::string byte_str(1, *j);
|
||||
token_multibyte = vocab.text_to_token(byte_str);
|
||||
} else {
|
||||
// For non-byte-encoded BPE (e.g. gemma-4), byte tokens use <0xXX> format
|
||||
static const char * hex = "0123456789ABCDEF";
|
||||
const uint8_t ch = (uint8_t)*j;
|
||||
const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
|
||||
token_multibyte = vocab.text_to_token(buf);
|
||||
}
|
||||
if (token_multibyte != LLAMA_TOKEN_NULL) {
|
||||
output.push_back(token_multibyte);
|
||||
}
|
||||
@@ -2558,7 +2567,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|| t.first == "[EOS]" // Kimi-K2
|
||||
|| t.first == "<|end_of_text|>"
|
||||
|| t.first == "<end_of_utterance>" // smoldocling
|
||||
|| t.first == "<turn|>" // gemma4
|
||||
|| t.first == "<eos>" // gemma4
|
||||
|| t.first == "<turn|>" // gemma4
|
||||
|| t.first == "<|tool_response>" // gemma4
|
||||
|| t.first == "<|end▁of▁sentence|>" // deepseek-ocr
|
||||
) {
|
||||
@@ -2645,6 +2655,33 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
// workaround for gemma4 and paddleocr: do not include </s> as an eog token
|
||||
{
|
||||
bool has_tool_response = false;
|
||||
bool has_s = false;
|
||||
|
||||
llama_token s_id = LLAMA_TOKEN_NULL;
|
||||
|
||||
for (auto tid : special_eog_ids) {
|
||||
const auto & text = id_to_token[tid].text;
|
||||
if (text == "<|tool_response>") {
|
||||
has_tool_response = true;
|
||||
} else if (text == "</s>") {
|
||||
has_s = true;
|
||||
s_id = tid;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_tool_response && has_s) {
|
||||
special_eog_ids.erase(s_id);
|
||||
|
||||
auto & attr = id_to_token[s_id].attr;
|
||||
attr = LLAMA_TOKEN_ATTR_NORMAL;
|
||||
|
||||
LLAMA_LOG_WARN("%s: special_eog_ids contains '<|tool_response>', removing '</s>' token from EOG list\n", __func__);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// build special tokens cache
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
#include "models.h"
|
||||
|
||||
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
|
||||
static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) {
|
||||
GGML_ASSERT(idx < (int) x->ne[2]);
|
||||
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
|
||||
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
|
||||
}
|
||||
|
||||
llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params),
|
||||
model(model),
|
||||
@@ -22,8 +29,11 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
// TODO: is causal == true correct? might need some changes
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
|
||||
// inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
|
||||
ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
|
||||
ggml_tensor * inp_per_layer = build_inp_per_layer();
|
||||
ggml_build_forward_expand(gf, inp_per_layer);
|
||||
|
||||
// inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
|
||||
inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer);
|
||||
|
||||
// inpL now has only 1 altup, project it to the rest of the altups
|
||||
// these "added" altups will be concat to the last dim of inpL
|
||||
@@ -37,8 +47,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
|
||||
cb(inpL, "inp_stacked", -1);
|
||||
}
|
||||
// inpL now has shape: [n_embd, n_tokens, n_altup]
|
||||
// inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
|
||||
// inpL now has shape: [n_embd, n_tokens, n_altup]
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// this block is made to be closely resemble Gemma3p5DecoderLayer on python code
|
||||
@@ -49,8 +58,8 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]
|
||||
|
||||
// predicted value will go through self-attention and laurel
|
||||
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
|
||||
cur = active_prediction;
|
||||
ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); // [n_embd, n_tokens]
|
||||
cur = active_prediction;
|
||||
cb(cur, "active_prediction", il);
|
||||
|
||||
// norm
|
||||
@@ -151,12 +160,13 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
|
||||
ggml_tensor * first_prediction; // [n_embd, n_tokens]
|
||||
{
|
||||
first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
|
||||
first_prediction = ggml_view_2d_slice(ctx0, corrected, i_altup_act); // [n_embd, n_tokens]
|
||||
first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
|
||||
first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
|
||||
first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
|
||||
cb(first_prediction, "first_prediction_gated", il);
|
||||
ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
|
||||
|
||||
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_altup, n_tokens]
|
||||
first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
|
||||
cb(first_prediction, "first_prediction_scaled", il);
|
||||
|
||||
@@ -167,7 +177,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
}
|
||||
// equivalent to python code: corrected_predictions[1:] += first_prediction
|
||||
{
|
||||
ggml_tensor * slice_first = view_2d_slice(corrected, 0);
|
||||
ggml_tensor * slice_first = ggml_view_2d_slice(ctx0, corrected, 0);
|
||||
ggml_tensor * slice_rest = ggml_view_3d(
|
||||
ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd),
|
||||
ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected));
|
||||
@@ -185,7 +195,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
|
||||
// cur now has multiple altup(s), we want to merge them back to 1 altup
|
||||
{
|
||||
ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
|
||||
ggml_tensor * target_magnitude = calc_magnitude(ggml_view_2d_slice(ctx0, cur, i_altup_act)); // [n_embd, n_tokens]
|
||||
// do a view to skip the first slice (active altup)
|
||||
ggml_tensor * alt_slice =
|
||||
ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd),
|
||||
@@ -197,9 +207,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
cb(altup_unembd, "altup_unembd", -1);
|
||||
|
||||
// equivalent to torch.mean(hidden_states, dim=0)
|
||||
cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
|
||||
cur = ggml_view_2d_slice(ctx0, cur, 0); // [n_embd, n_tokens]
|
||||
for (int i = 0; i < n_altup - 1; ++i) {
|
||||
cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
|
||||
cur = ggml_add(ctx0, cur, ggml_view_2d_slice(ctx0, altup_unembd, i));
|
||||
}
|
||||
cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
|
||||
cb(cur, "unembd_merged", -1);
|
||||
@@ -235,39 +245,34 @@ ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) {
|
||||
return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
|
||||
}
|
||||
|
||||
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
|
||||
ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
|
||||
GGML_ASSERT(idx < (int) x->ne[2]);
|
||||
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
|
||||
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
|
||||
}
|
||||
|
||||
// equivalent to get_per_layer_inputs() in python code
|
||||
// output shape: [n_embd_altup, n_layer, n_tokens]
|
||||
ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
|
||||
ggml_tensor * llm_build_gemma3n_iswa::build_inp_per_layer() {
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
|
||||
ggml_tensor * inp_per_layer;
|
||||
float tok_embd_scale = sqrtf((float) n_embd_altup);
|
||||
if (ubatch.token) {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_inp_tokens = inp->tokens;
|
||||
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
|
||||
inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens);
|
||||
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
|
||||
inp_per_layer = ggml_scale (ctx0, inp_per_layer, tok_embd_scale);
|
||||
cb(inp_per_layer, "inp_per_layer_selected", -1);
|
||||
res->add_input(std::move(inp));
|
||||
} else {
|
||||
// Vision embedding path: use padding token (ID=0) embedding
|
||||
// Multimodal embedding path: use padding token (ID=0) embedding
|
||||
// TODO: verify if this is the correct behavior in transformers implementation
|
||||
const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer
|
||||
const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_altup * n_layer
|
||||
|
||||
// Extract and dequantize padding token embedding (row 0)
|
||||
ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
|
||||
inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
|
||||
ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0);
|
||||
inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, tok_embd_scale);
|
||||
|
||||
// Reshape to [n_embd_altup, n_layer, 1]
|
||||
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1);
|
||||
cb(inp_per_layer, "inp_per_layer_vision", -1);
|
||||
cb(inp_per_layer, "inp_per_layer_multimodal", -1);
|
||||
}
|
||||
return inp_per_layer;
|
||||
}
|
||||
@@ -275,18 +280,19 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
|
||||
// equivalent to project_per_layer_inputs() in python code
|
||||
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
|
||||
// output shape: [n_embd_altup, n_tokens, n_layer]
|
||||
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
|
||||
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) {
|
||||
const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
|
||||
const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
|
||||
|
||||
ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
|
||||
per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
|
||||
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
|
||||
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS,
|
||||
-1); // [n_embd_altup, n_layer, n_tokens]
|
||||
ggml_tensor * per_layer_proj;
|
||||
per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch);
|
||||
per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale);
|
||||
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
|
||||
|
||||
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, -1);
|
||||
cb(per_layer_proj, "per_layer_proj", -1);
|
||||
|
||||
inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
|
||||
inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
|
||||
cb(inp_per_layer, "inp_per_layer", -1);
|
||||
|
||||
@@ -337,7 +343,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tenso
|
||||
// input cur shape: [n_embd, n_tokens, n_altup]
|
||||
// output shape: [n_embd, n_tokens, n_altup]
|
||||
ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) {
|
||||
ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
|
||||
ggml_tensor * activated = ggml_view_2d_slice(ctx0, cur, i_altup_act); // [n_embd, n_tokens]
|
||||
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
|
||||
cb(modalities, "modalities", il);
|
||||
|
||||
@@ -365,7 +371,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, g
|
||||
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
|
||||
cb(modalities, "modalities", il);
|
||||
|
||||
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
|
||||
ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act);
|
||||
ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
|
||||
cb(innovation, "innovation", il);
|
||||
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
#include "models.h"
|
||||
|
||||
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
|
||||
static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) {
|
||||
GGML_ASSERT(idx < (int) x->ne[2]);
|
||||
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
|
||||
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
|
||||
}
|
||||
|
||||
llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params),
|
||||
model(model),
|
||||
@@ -19,14 +26,17 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
|
||||
// TODO: is causal == true correct? might need some changes
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
|
||||
// inp_per_layer shape: [n_embd_per_layer, n_tokens, n_layer]
|
||||
ggml_tensor * inp_per_layer = nullptr;
|
||||
if (model.tok_embd_per_layer) {
|
||||
inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
|
||||
}
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
ggml_tensor * inp_per_layer = nullptr;
|
||||
if (model.per_layer_tok_embd) {
|
||||
inp_per_layer = build_inp_per_layer();
|
||||
ggml_build_forward_expand(gf, inp_per_layer);
|
||||
|
||||
// inp_per_layer shape: [n_embd_per_layer, n_tokens, n_layer]
|
||||
inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer);
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k(il);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v(il));
|
||||
@@ -196,7 +206,8 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
|
||||
|
||||
cur = build_lora_mm(model.layers[il].per_layer_inp_gate, cur); // [n_embd_per_layer, n_tokens]
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_per_layer, n_tokens]
|
||||
|
||||
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens]
|
||||
|
||||
// TODO @ngxson : improve this
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -248,60 +259,60 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
|
||||
ggml_tensor * llm_build_gemma4_iswa::view_2d_slice(ggml_tensor * x, int idx) {
|
||||
GGML_ASSERT(idx < (int) x->ne[2]);
|
||||
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
|
||||
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
|
||||
}
|
||||
|
||||
// equivalent to get_per_layer_inputs() in python code
|
||||
// output shape: [n_embd_per_layer, n_layer, n_tokens]
|
||||
ggml_tensor * llm_build_gemma4_iswa::get_per_layer_inputs() {
|
||||
ggml_tensor * llm_build_gemma4_iswa::build_inp_per_layer() {
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
|
||||
|
||||
ggml_tensor * inp_per_layer;
|
||||
float tok_embd_scale = sqrtf((float) n_embd_per_layer);
|
||||
if (ubatch.token) {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_inp_tokens = inp->tokens;
|
||||
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
|
||||
|
||||
inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens);
|
||||
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, n_tokens);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_per_layer));
|
||||
inp_per_layer = ggml_scale (ctx0, inp_per_layer, tok_embd_scale);
|
||||
cb(inp_per_layer, "inp_per_layer_selected", -1);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
} else {
|
||||
// Vision embedding path: use padding token (ID=0) embedding
|
||||
// Multimodal embedding path: use padding token (ID=0) embedding
|
||||
// TODO: verify if this is the correct behavior in transformers implementation
|
||||
const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_per_layer * n_layer
|
||||
const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_per_layer * n_layer
|
||||
|
||||
// Extract and dequantize padding token embedding (row 0)
|
||||
ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
|
||||
inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
|
||||
ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0);
|
||||
inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, tok_embd_scale);
|
||||
|
||||
// Reshape to [n_embd_per_layer, n_layer, 1]
|
||||
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, 1);
|
||||
cb(inp_per_layer, "inp_per_layer_vision", -1);
|
||||
cb(inp_per_layer, "inp_per_layer_multimodal", -1);
|
||||
}
|
||||
return inp_per_layer;
|
||||
}
|
||||
|
||||
// equivalent to project_per_layer_inputs() in python code
|
||||
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
|
||||
// inputs_embeds shape: [n_embd, n_tokens]
|
||||
// inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from get_per_layer_inputs)
|
||||
// inp_batch shape: [n_embd, n_tokens]
|
||||
// inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from build_inp_per_layer)
|
||||
// output shape: [n_embd_per_layer, n_tokens, n_layer]
|
||||
ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
|
||||
ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) {
|
||||
const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
|
||||
const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
|
||||
|
||||
ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
|
||||
per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
|
||||
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_per_layer, n_layer, n_tokens);
|
||||
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, nullptr, LLM_NORM_RMS,
|
||||
-1); // [n_embd_per_layer, n_layer, n_tokens]
|
||||
// note: this matrix multiplication will be performed in the input layer (i.e. on the CPU)
|
||||
ggml_tensor * per_layer_proj;
|
||||
per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch);
|
||||
per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale);
|
||||
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_per_layer, n_layer, n_tokens);
|
||||
|
||||
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
cb(per_layer_proj, "per_layer_proj", -1);
|
||||
|
||||
inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
|
||||
inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
|
||||
cb(inp_per_layer, "inp_per_layer", -1);
|
||||
|
||||
|
||||
@@ -256,9 +256,11 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
||||
|
||||
llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params);
|
||||
ggml_tensor * calc_magnitude(ggml_tensor * x);
|
||||
ggml_tensor * view_2d_slice(ggml_tensor * x, int idx);
|
||||
ggml_tensor * get_per_layer_inputs();
|
||||
ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer);
|
||||
|
||||
// TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER]
|
||||
ggml_tensor * build_inp_per_layer();
|
||||
ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer);
|
||||
|
||||
ggml_tensor * gaussian_topk(ggml_tensor * x);
|
||||
ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il);
|
||||
ggml_tensor * altup_predict(ggml_tensor * cur, int il);
|
||||
@@ -272,9 +274,10 @@ struct llm_build_gemma4_iswa : public llm_graph_context {
|
||||
const int64_t n_embd_per_layer;
|
||||
|
||||
llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params);
|
||||
ggml_tensor * view_2d_slice(ggml_tensor * x, int idx);
|
||||
ggml_tensor * get_per_layer_inputs();
|
||||
ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer);
|
||||
|
||||
// TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER]
|
||||
ggml_tensor * build_inp_per_layer();
|
||||
ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer);
|
||||
};
|
||||
|
||||
struct llm_build_gemma_embedding : public llm_graph_context {
|
||||
|
||||
@@ -124,6 +124,7 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-command-r ARGS ${PROJE
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-coder ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-deepseek-coder.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-llm ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-deepseek-llm.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-falcon ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-falcon.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-gemma-4 ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-gemma-4.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-gpt-2 ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-gpt-2.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-bpe ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-bpe.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-spm ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-spm.gguf)
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
import { readFileSync } from "fs"
|
||||
import { SchemaConverter } from "../tools/server/public_legacy/json-schema-to-grammar.mjs"
|
||||
|
||||
const [, , file] = process.argv
|
||||
const url = `file://${file}`
|
||||
let schema = JSON.parse(readFileSync(file, "utf8"));
|
||||
const converter = new SchemaConverter({})
|
||||
schema = await converter.resolveRefs(schema, url)
|
||||
converter.visit(schema, '')
|
||||
console.log(converter.formatGrammar())
|
||||
@@ -7251,6 +7251,7 @@ static const ggml_type all_types[] = {
|
||||
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
|
||||
GGML_TYPE_Q8_0,
|
||||
GGML_TYPE_Q1_0,
|
||||
GGML_TYPE_MXFP4, GGML_TYPE_NVFP4,
|
||||
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
|
||||
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
|
||||
@@ -7275,6 +7276,7 @@ static const ggml_type other_types[] = {
|
||||
GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
|
||||
GGML_TYPE_Q8_0,
|
||||
GGML_TYPE_Q1_0,
|
||||
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
|
||||
GGML_TYPE_Q5_K,
|
||||
GGML_TYPE_Q6_K,
|
||||
|
||||
@@ -998,6 +998,7 @@ static void test_peg_parser(common_chat_templates * tmpls,
|
||||
auto parser = make_peg_parser(tmpls, tc.params, detailed_debug);
|
||||
if (detailed_debug) {
|
||||
LOG_DBG("Using parser: \n%s\n", parser.arena_.dump(parser.arena_.root()).c_str());
|
||||
LOG_DBG("Generation prompt: '%s'\n", parser.params_.generation_prompt.c_str());
|
||||
}
|
||||
|
||||
common_chat_msg msg_accum;
|
||||
@@ -3102,8 +3103,19 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// Format: <minimax:tool_call><invoke name="func"><parameter name="key">value</parameter></invoke></minimax:tool_call>
|
||||
{
|
||||
auto tst = peg_tester("models/templates/MiniMax-M2.jinja", detailed_debug);
|
||||
tst.test("</think>Hello, world!\nWhat's up?").enable_thinking(true).reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(message_assist).run();
|
||||
|
||||
tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?").enable_thinking(true).reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(message_assist_thoughts).run();
|
||||
|
||||
tst.test("Let's call a tool:</think><minimax:tool_call>\n<invoke name=\"empty_args\">\n</invoke>\n</minimax:tool_call>").
|
||||
enable_thinking(true).
|
||||
reasoning_format(COMMON_REASONING_FORMAT_AUTO).
|
||||
tools({ empty_args_tool }).
|
||||
expect(message_with_reasoning_and_tool_call("Let's call a tool:", "empty_args", "{}")).
|
||||
run();
|
||||
|
||||
tst.test(
|
||||
"<minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
|
||||
"</think><minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
|
||||
"name=\"arg1\">1</parameter>\n</invoke>\n</minimax:tool_call>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
@@ -3442,7 +3454,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
},
|
||||
"replaceAll": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to replace all occurences."
|
||||
"description": "Whether to replace all occurrences."
|
||||
}
|
||||
},
|
||||
"required": ["oldString", "newString"]
|
||||
|
||||
@@ -447,6 +447,18 @@ static void test_expressions(testing & t) {
|
||||
"hello world"
|
||||
);
|
||||
|
||||
test_template(t, "string repetition",
|
||||
"{{ 'ab' * 3 }}",
|
||||
json::object(),
|
||||
"ababab"
|
||||
);
|
||||
|
||||
test_template(t, "reversed string repetition",
|
||||
"{{ 3 * 'ab' }}",
|
||||
json::object(),
|
||||
"ababab"
|
||||
);
|
||||
|
||||
test_template(t, "ternary",
|
||||
"{{ 'yes' if cond else 'no' }}",
|
||||
{{"cond", true}},
|
||||
@@ -693,6 +705,33 @@ static void test_filters(testing & t) {
|
||||
"\"\\u2713\""
|
||||
);
|
||||
|
||||
test_template(t, "tojson ensure_ascii=true nested object",
|
||||
"{{ data|tojson(ensure_ascii=true) }}",
|
||||
{{"data", {
|
||||
{"text", "\u2713"},
|
||||
{"items", json::array({"é", {{"snowman", "☃"}}})}
|
||||
}}},
|
||||
"{\"text\": \"\\u2713\", \"items\": [\"\\u00e9\", {\"snowman\": \"\\u2603\"}]}"
|
||||
);
|
||||
|
||||
test_template(t, "tojson ensure_ascii=true indent=2",
|
||||
"{{ data|tojson(ensure_ascii=true, indent=2) }}",
|
||||
{{"data", {
|
||||
{"text", "\u2713"},
|
||||
{"nested", {{"accent", "é"}}}
|
||||
}}},
|
||||
"{\n \"text\": \"\\u2713\",\n \"nested\": {\n \"accent\": \"\\u00e9\"\n }\n}"
|
||||
);
|
||||
|
||||
test_template(t, "tojson ensure_ascii=true preserves existing escapes",
|
||||
"{{ data|tojson(ensure_ascii=true) }}",
|
||||
{{"data", {
|
||||
{"emoji", "😀"},
|
||||
{"line", "a\nb"}
|
||||
}}},
|
||||
"{\"emoji\": \"\\ud83d\\ude00\", \"line\": \"a\\nb\"}"
|
||||
);
|
||||
|
||||
test_template(t, "tojson sort_keys=true",
|
||||
"{{ data|tojson(sort_keys=true) }}",
|
||||
{{"data", {{"b", 2}, {"a", 1}}}},
|
||||
@@ -771,6 +810,12 @@ static void test_filters(testing & t) {
|
||||
"hello"
|
||||
);
|
||||
|
||||
test_template(t, "int filter on integer is identity",
|
||||
"{{ value|int }}",
|
||||
{{"value", 7}},
|
||||
"7"
|
||||
);
|
||||
|
||||
test_template(t, "none to string",
|
||||
"{{ x|string }}",
|
||||
{{"x", nullptr}},
|
||||
@@ -2458,4 +2503,12 @@ static void test_fuzzing(testing & t) {
|
||||
t.assert_true("builtin " + type_name + "." + fn_name + " #" + std::to_string(i), fuzz_test_template(tmpl, vars));
|
||||
}
|
||||
});
|
||||
|
||||
t.test("tojson ensure_ascii=true with invalid utf-8", [&](testing & t) {
|
||||
t.assert_true("invalid utf-8 does not crash",
|
||||
fuzz_test_template(
|
||||
"{{ data|tojson(ensure_ascii=true) }}",
|
||||
{{"data", std::string("hello\xfe\xffworld")}}
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1579,17 +1579,6 @@ int main() {
|
||||
} else {
|
||||
fprintf(stderr, "\033[33mWARNING: Python not found (min version required is 3.8), skipping Python JSON schema -> grammar tests.\n\033[0m");
|
||||
}
|
||||
|
||||
if (getenv("LLAMA_NODE_AVAILABLE") || (std::system("node --version") == 0)) {
|
||||
test_all("JavaScript", [](const TestCase & tc) {
|
||||
write("test-json-schema-input.tmp", tc.schema);
|
||||
tc.verify_status(std::system(
|
||||
"node ./tests/run-json-schema-to-grammar.mjs test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
|
||||
tc.verify(read("test-grammar-output.tmp"));
|
||||
});
|
||||
} else {
|
||||
fprintf(stderr, "\033[33mWARNING: Node not found, skipping JavaScript JSON schema -> grammar tests.\n\033[0m");
|
||||
}
|
||||
}
|
||||
|
||||
test_all("Check Expectations Validity", [](const TestCase & tc) {
|
||||
|
||||
@@ -19,7 +19,7 @@ with open(fname_tok, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
s = ''.join(lines)
|
||||
t_start = time.time()
|
||||
res = tokenizer.encode(s, add_special_tokens=False)
|
||||
res = tokenizer.encode(s, add_special_tokens=False) # ty: ignore[unresolved-attribute]
|
||||
t_end = time.time()
|
||||
print('\nmain : tokenized in', "{:.3f}".format(1000.0 * (t_end - t_start)), 'ms (py)') # noqa: NP100
|
||||
with open(fname_out, 'w', encoding='utf-8') as f:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user