mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-16 16:27:32 +03:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a29e4c0b7b | ||
|
|
b136b62cf9 | ||
|
|
81069a808a | ||
|
|
9aa2807769 | ||
|
|
3fc65063d9 | ||
|
|
05b3caaa48 | ||
|
|
e62fa13c24 | ||
|
|
bfd1f453cb | ||
|
|
e4fed9d08d | ||
|
|
5dd102539b | ||
|
|
fb38d6f278 |
17
cmake/arm64-linux-clang.cmake
Normal file
17
cmake/arm64-linux-clang.cmake
Normal file
@@ -0,0 +1,17 @@
|
||||
set( CMAKE_SYSTEM_NAME Linux )
|
||||
set( CMAKE_SYSTEM_PROCESSOR arm64 )
|
||||
|
||||
set( target aarch64-linux-gnu )
|
||||
|
||||
set( CMAKE_C_COMPILER clang )
|
||||
set( CMAKE_CXX_COMPILER clang++ )
|
||||
|
||||
set( CMAKE_C_COMPILER_TARGET ${target} )
|
||||
set( CMAKE_CXX_COMPILER_TARGET ${target} )
|
||||
|
||||
set( arch_c_flags "-march=armv8.7-a -fvectorize -ffp-model=fast -fno-finite-math-only" )
|
||||
set( warn_c_flags "-Wno-format -Wno-unused-variable -Wno-unused-function -Wno-gnu-zero-variadic-macro-arguments" )
|
||||
|
||||
set( CMAKE_C_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" )
|
||||
set( CMAKE_CXX_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" )
|
||||
|
||||
@@ -291,14 +291,16 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
|
||||
hf_tag = "default";
|
||||
}
|
||||
|
||||
const bool offline = params.offline;
|
||||
std::string model_endpoint = get_model_endpoint();
|
||||
auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";
|
||||
|
||||
// prepare local path for caching
|
||||
auto preset_fname = clean_file_name(hf_repo + "_preset.ini");
|
||||
auto preset_path = fs_get_cache_file(preset_fname);
|
||||
const int status = common_download_file_single(preset_url, preset_path, params.hf_token, offline);
|
||||
common_download_opts opts;
|
||||
opts.bearer_token = params.hf_token;
|
||||
opts.offline = params.offline;
|
||||
const int status = common_download_file_single(preset_url, preset_path, opts);
|
||||
const bool has_preset = status >= 200 && status < 400;
|
||||
|
||||
// remote preset is optional, so we don't error out if not found
|
||||
@@ -341,10 +343,10 @@ static handle_model_result common_params_handle_model(struct common_params_model
|
||||
model.hf_file = model.path;
|
||||
model.path = "";
|
||||
}
|
||||
common_download_model_opts opts;
|
||||
opts.download_mmproj = true;
|
||||
common_download_opts opts;
|
||||
opts.bearer_token = bearer_token;
|
||||
opts.offline = offline;
|
||||
auto download_result = common_download_model(model, bearer_token, opts);
|
||||
auto download_result = common_download_model(model, opts, true);
|
||||
|
||||
if (download_result.model_path.empty()) {
|
||||
LOG_ERR("error: failed to download model from Hugging Face\n");
|
||||
@@ -365,9 +367,10 @@ static handle_model_result common_params_handle_model(struct common_params_model
|
||||
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
|
||||
}
|
||||
|
||||
common_download_model_opts opts;
|
||||
common_download_opts opts;
|
||||
opts.bearer_token = bearer_token;
|
||||
opts.offline = offline;
|
||||
auto download_result = common_download_model(model, bearer_token, opts);
|
||||
auto download_result = common_download_model(model, opts);
|
||||
if (download_result.model_path.empty()) {
|
||||
LOG_ERR("error: failed to download model from %s\n", model.url.c_str());
|
||||
exit(1);
|
||||
|
||||
@@ -69,6 +69,10 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
||||
auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
|
||||
@@ -865,9 +865,10 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
||||
adjusted_messages.push_back(adjusted);
|
||||
}
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = true;
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = true;
|
||||
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "[THINK]";
|
||||
@@ -887,7 +888,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
||||
extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps();
|
||||
|
||||
// Response format parser
|
||||
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
|
||||
if (has_response_format) {
|
||||
// Ministral wants to emit json surrounded by code fences
|
||||
return generation_prompt + (reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```");
|
||||
}
|
||||
@@ -928,6 +929,10 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
@@ -1063,6 +1068,10 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
@@ -1193,6 +1202,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
@@ -1916,7 +1929,12 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
|
||||
// Gemma4 format detection
|
||||
if (src.find("'<|tool_call>call:'") != std::string::npos) {
|
||||
workaround::convert_tool_responses_gemma4(params.messages);
|
||||
if (src.find("{#- OpenAI Chat Completions:") == std::string::npos) {
|
||||
// apply workarounds if using the older gemma4 templates
|
||||
LOG_WRN("%s: detected an outdated gemma4 chat template, applying compatibility workarounds. "
|
||||
"Consider updating to the official template.\n", __func__);
|
||||
workaround::convert_tool_responses_gemma4(params.messages);
|
||||
}
|
||||
return common_chat_params_init_gemma4(tmpl, params);
|
||||
}
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ std::pair<std::string, std::string> common_download_split_repo_tag(const std::st
|
||||
return {hf_repo, tag};
|
||||
}
|
||||
|
||||
class ProgressBar {
|
||||
class ProgressBar : public common_download_callback {
|
||||
static inline std::mutex mutex;
|
||||
static inline std::map<const ProgressBar *, int> lines;
|
||||
static inline int max_line = 0;
|
||||
@@ -138,7 +138,11 @@ class ProgressBar {
|
||||
}
|
||||
|
||||
public:
|
||||
ProgressBar(const std::string & url = "") : filename(url) {
|
||||
ProgressBar() = default;
|
||||
|
||||
void on_start(const common_download_progress & p) override {
|
||||
filename = p.url;
|
||||
|
||||
if (auto pos = filename.rfind('/'); pos != std::string::npos) {
|
||||
filename = filename.substr(pos + 1);
|
||||
}
|
||||
@@ -156,13 +160,13 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
~ProgressBar() {
|
||||
void on_done(const common_download_progress &, bool) override {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
cleanup(this);
|
||||
}
|
||||
|
||||
void update(size_t current, size_t total) {
|
||||
if (!total || !is_output_a_tty()) {
|
||||
void on_update(const common_download_progress & p) override {
|
||||
if (!p.total || !is_output_a_tty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -175,8 +179,8 @@ public:
|
||||
int lines_up = max_line - lines[this];
|
||||
|
||||
size_t bar = (55 - len) * 2;
|
||||
size_t pct = (100 * current) / total;
|
||||
size_t pos = (bar * current) / total;
|
||||
size_t pct = (100 * p.downloaded) / p.total;
|
||||
size_t pos = (bar * p.downloaded) / p.total;
|
||||
|
||||
if (lines_up > 0) {
|
||||
std::cout << "\033[" << lines_up << "A";
|
||||
@@ -193,7 +197,7 @@ public:
|
||||
}
|
||||
std::cout << '\r' << std::flush;
|
||||
|
||||
if (current == total) {
|
||||
if (p.downloaded == p.total) {
|
||||
cleanup(this);
|
||||
}
|
||||
}
|
||||
@@ -206,8 +210,8 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
const std::string & resolve_path,
|
||||
const std::string & path_tmp,
|
||||
bool supports_ranges,
|
||||
size_t existing_size,
|
||||
size_t & total_size) {
|
||||
common_download_progress & p,
|
||||
common_download_callback * callback) {
|
||||
std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app);
|
||||
if (!ofs.is_open()) {
|
||||
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str());
|
||||
@@ -215,29 +219,27 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
}
|
||||
|
||||
httplib::Headers headers;
|
||||
if (supports_ranges && existing_size > 0) {
|
||||
headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
|
||||
if (supports_ranges && p.downloaded > 0) {
|
||||
headers.emplace("Range", "bytes=" + std::to_string(p.downloaded) + "-");
|
||||
}
|
||||
|
||||
const char * func = __func__; // avoid __func__ inside a lambda
|
||||
size_t downloaded = existing_size;
|
||||
size_t progress_step = 0;
|
||||
ProgressBar bar(resolve_path);
|
||||
|
||||
auto res = cli.Get(resolve_path, headers,
|
||||
[&](const httplib::Response &response) {
|
||||
if (existing_size > 0 && response.status != 206) {
|
||||
if (p.downloaded > 0 && response.status != 206) {
|
||||
LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status);
|
||||
return false;
|
||||
}
|
||||
if (existing_size == 0 && response.status != 200) {
|
||||
if (p.downloaded == 0 && response.status != 200) {
|
||||
LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status);
|
||||
return false;
|
||||
}
|
||||
if (total_size == 0 && response.has_header("Content-Length")) {
|
||||
if (p.total == 0 && response.has_header("Content-Length")) {
|
||||
try {
|
||||
size_t content_length = std::stoull(response.get_header_value("Content-Length"));
|
||||
total_size = existing_size + content_length;
|
||||
p.total = p.downloaded + content_length;
|
||||
} catch (const std::exception &e) {
|
||||
LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what());
|
||||
}
|
||||
@@ -250,11 +252,13 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str());
|
||||
return false;
|
||||
}
|
||||
downloaded += len;
|
||||
p.downloaded += len;
|
||||
progress_step += len;
|
||||
|
||||
if (progress_step >= total_size / 1000 || downloaded == total_size) {
|
||||
bar.update(downloaded, total_size);
|
||||
if (progress_step >= p.total / 1000 || p.downloaded == p.total) {
|
||||
if (callback) {
|
||||
callback->on_update(p);
|
||||
}
|
||||
progress_step = 0;
|
||||
}
|
||||
return true;
|
||||
@@ -275,28 +279,13 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
|
||||
// download one single file from remote URL to local path
|
||||
// returns status code or -1 on error
|
||||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
const common_header_list & custom_headers,
|
||||
bool skip_etag = false) {
|
||||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const common_download_opts & opts,
|
||||
bool skip_etag) {
|
||||
static const int max_attempts = 3;
|
||||
static const int retry_delay_seconds = 2;
|
||||
|
||||
auto [cli, parts] = common_http_client(url);
|
||||
|
||||
httplib::Headers headers;
|
||||
for (const auto & h : custom_headers) {
|
||||
headers.emplace(h.first, h.second);
|
||||
}
|
||||
if (headers.find("User-Agent") == headers.end()) {
|
||||
headers.emplace("User-Agent", "llama-cpp/" + build_info);
|
||||
}
|
||||
if (!bearer_token.empty()) {
|
||||
headers.emplace("Authorization", "Bearer " + bearer_token);
|
||||
}
|
||||
cli.set_default_headers(headers);
|
||||
|
||||
const bool file_exists = std::filesystem::exists(path);
|
||||
|
||||
if (file_exists && skip_etag) {
|
||||
@@ -304,6 +293,20 @@ static int common_download_file_single_online(const std::string & url,
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
|
||||
auto [cli, parts] = common_http_client(url);
|
||||
|
||||
httplib::Headers headers;
|
||||
for (const auto & h : opts.headers) {
|
||||
headers.emplace(h.first, h.second);
|
||||
}
|
||||
if (headers.find("User-Agent") == headers.end()) {
|
||||
headers.emplace("User-Agent", "llama-cpp/" + build_info);
|
||||
}
|
||||
if (!opts.bearer_token.empty()) {
|
||||
headers.emplace("Authorization", "Bearer " + opts.bearer_token);
|
||||
}
|
||||
cli.set_default_headers(headers);
|
||||
|
||||
std::string last_etag;
|
||||
if (file_exists) {
|
||||
last_etag = read_etag(path);
|
||||
@@ -326,10 +329,11 @@ static int common_download_file_single_online(const std::string & url,
|
||||
etag = head->get_header_value("ETag");
|
||||
}
|
||||
|
||||
size_t total_size = 0;
|
||||
common_download_progress p;
|
||||
p.url = url;
|
||||
if (head->has_header("Content-Length")) {
|
||||
try {
|
||||
total_size = std::stoull(head->get_header_value("Content-Length"));
|
||||
p.total = std::stoull(head->get_header_value("Content-Length"));
|
||||
} catch (const std::exception& e) {
|
||||
LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what());
|
||||
}
|
||||
@@ -357,13 +361,17 @@ static int common_download_file_single_online(const std::string & url,
|
||||
|
||||
{ // silent
|
||||
std::error_code ec;
|
||||
std::filesystem::path p(path);
|
||||
std::filesystem::create_directories(p.parent_path(), ec);
|
||||
std::filesystem::create_directories(std::filesystem::path(path).parent_path(), ec);
|
||||
}
|
||||
|
||||
bool success = false;
|
||||
const std::string path_temporary = path + ".downloadInProgress";
|
||||
int delay = retry_delay_seconds;
|
||||
|
||||
if (opts.callback) {
|
||||
opts.callback->on_start(p);
|
||||
}
|
||||
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
if (i) {
|
||||
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
|
||||
@@ -378,28 +386,38 @@ static int common_download_file_single_online(const std::string & url,
|
||||
existing_size = std::filesystem::file_size(path_temporary);
|
||||
} else if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
|
||||
return -1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
p.downloaded = existing_size;
|
||||
|
||||
LOG_DBG("%s: downloading from %s to %s (etag:%s)...\n",
|
||||
__func__, common_http_show_masked_url(parts).c_str(),
|
||||
path_temporary.c_str(), etag.c_str());
|
||||
|
||||
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) {
|
||||
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, p, opts.callback)) {
|
||||
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return -1;
|
||||
break;
|
||||
}
|
||||
if (!etag.empty() && !skip_etag) {
|
||||
write_etag(path, etag);
|
||||
}
|
||||
return head->status;
|
||||
success = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
return -1; // max attempts reached
|
||||
if (opts.callback) {
|
||||
opts.callback->on_done(p, success);
|
||||
}
|
||||
if (!success) {
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
return -1; // max attempts reached
|
||||
}
|
||||
|
||||
return head->status;
|
||||
}
|
||||
|
||||
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
|
||||
@@ -438,12 +456,15 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
|
||||
|
||||
int common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers,
|
||||
const common_download_opts & opts,
|
||||
bool skip_etag) {
|
||||
if (!offline) {
|
||||
return common_download_file_single_online(url, path, bearer_token, headers, skip_etag);
|
||||
if (!opts.offline) {
|
||||
ProgressBar tty_cb;
|
||||
common_download_opts online_opts = opts;
|
||||
if (!online_opts.callback) {
|
||||
online_opts.callback = &tty_cb;
|
||||
}
|
||||
return common_download_file_single_online(url, path, online_opts, skip_etag);
|
||||
}
|
||||
|
||||
if (!std::filesystem::exists(path)) {
|
||||
@@ -452,6 +473,16 @@ int common_download_file_single(const std::string & url,
|
||||
}
|
||||
|
||||
LOG_DBG("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
|
||||
|
||||
// notify the callback that the file was cached
|
||||
if (opts.callback) {
|
||||
common_download_progress p;
|
||||
p.url = url;
|
||||
p.cached = true;
|
||||
opts.callback->on_start(p);
|
||||
opts.callback->on_done(p, true);
|
||||
}
|
||||
|
||||
return 304; // Not Modified - fake cached response
|
||||
}
|
||||
|
||||
@@ -631,16 +662,16 @@ struct hf_plan {
|
||||
hf_cache::hf_file mmproj;
|
||||
};
|
||||
|
||||
static hf_plan get_hf_plan(const common_params_model & model,
|
||||
const std::string & token,
|
||||
const common_download_model_opts & opts) {
|
||||
static hf_plan get_hf_plan(const common_params_model & model,
|
||||
const common_download_opts & opts,
|
||||
bool download_mmproj) {
|
||||
hf_plan plan;
|
||||
hf_cache::hf_files all;
|
||||
|
||||
auto [repo, tag] = common_download_split_repo_tag(model.hf_repo);
|
||||
|
||||
if (!opts.offline) {
|
||||
all = hf_cache::get_repo_files(repo, token);
|
||||
all = hf_cache::get_repo_files(repo, opts.bearer_token);
|
||||
}
|
||||
if (all.empty()) {
|
||||
all = hf_cache::get_cached_files(repo);
|
||||
@@ -675,7 +706,7 @@ static hf_plan get_hf_plan(const common_params_model & model,
|
||||
plan.primary = primary;
|
||||
plan.model_files = get_split_files(all, primary);
|
||||
|
||||
if (opts.download_mmproj) {
|
||||
if (download_mmproj) {
|
||||
plan.mmproj = find_best_mmproj(all, primary.path);
|
||||
}
|
||||
|
||||
@@ -710,10 +741,9 @@ static std::vector<download_task> get_url_tasks(const common_params_model & mode
|
||||
return tasks;
|
||||
}
|
||||
|
||||
common_download_model_result common_download_model(const common_params_model & model,
|
||||
const std::string & bearer_token,
|
||||
const common_download_model_opts & opts,
|
||||
const common_header_list & headers) {
|
||||
common_download_model_result common_download_model(const common_params_model & model,
|
||||
const common_download_opts & opts,
|
||||
bool download_mmproj) {
|
||||
common_download_model_result result;
|
||||
std::vector<download_task> tasks;
|
||||
hf_plan hf;
|
||||
@@ -721,7 +751,7 @@ common_download_model_result common_download_model(const common_params_model
|
||||
bool is_hf = !model.hf_repo.empty();
|
||||
|
||||
if (is_hf) {
|
||||
hf = get_hf_plan(model, bearer_token, opts);
|
||||
hf = get_hf_plan(model, opts, download_mmproj);
|
||||
for (const auto & f : hf.model_files) {
|
||||
tasks.push_back({f.url, f.local_path});
|
||||
}
|
||||
@@ -742,8 +772,8 @@ common_download_model_result common_download_model(const common_params_model
|
||||
std::vector<std::future<bool>> futures;
|
||||
for (const auto & task : tasks) {
|
||||
futures.push_back(std::async(std::launch::async,
|
||||
[&task, &bearer_token, offline = opts.offline, &headers, is_hf]() {
|
||||
int status = common_download_file_single(task.url, task.path, bearer_token, offline, headers, is_hf);
|
||||
[&task, &opts, is_hf]() {
|
||||
int status = common_download_file_single(task.url, task.path, opts, is_hf);
|
||||
return is_http_status_ok(status);
|
||||
}
|
||||
));
|
||||
@@ -879,7 +909,9 @@ std::string common_docker_resolve_model(const std::string & docker) {
|
||||
std::string local_path = fs_get_cache_file(model_filename);
|
||||
|
||||
const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
|
||||
const int http_status = common_download_file_single(blob_url, local_path, token, false, {});
|
||||
common_download_opts opts;
|
||||
opts.bearer_token = token;
|
||||
const int http_status = common_download_file_single(blob_url, local_path, opts);
|
||||
if (!is_http_status_ok(http_status)) {
|
||||
throw std::runtime_error("Failed to download Docker Model");
|
||||
}
|
||||
|
||||
@@ -8,6 +8,21 @@ struct common_params_model;
|
||||
using common_header = std::pair<std::string, std::string>;
|
||||
using common_header_list = std::vector<common_header>;
|
||||
|
||||
struct common_download_progress {
|
||||
std::string url;
|
||||
size_t downloaded = 0;
|
||||
size_t total = 0;
|
||||
bool cached = false;
|
||||
};
|
||||
|
||||
class common_download_callback {
|
||||
public:
|
||||
virtual ~common_download_callback() = default;
|
||||
virtual void on_start(const common_download_progress & p) = 0;
|
||||
virtual void on_update(const common_download_progress & p) = 0;
|
||||
virtual void on_done(const common_download_progress & p, bool ok) = 0;
|
||||
};
|
||||
|
||||
struct common_remote_params {
|
||||
common_header_list headers;
|
||||
long timeout = 0; // in seconds, 0 means no timeout
|
||||
@@ -31,10 +46,12 @@ struct common_cached_model_info {
|
||||
}
|
||||
};
|
||||
|
||||
// Options for common_download_model
|
||||
struct common_download_model_opts {
|
||||
bool download_mmproj = false;
|
||||
bool offline = false;
|
||||
// Options for common_download_model and common_download_file_single
|
||||
struct common_download_opts {
|
||||
std::string bearer_token;
|
||||
common_header_list headers;
|
||||
bool offline = false;
|
||||
common_download_callback * callback = nullptr;
|
||||
};
|
||||
|
||||
// Result of common_download_model
|
||||
@@ -69,9 +86,8 @@ struct common_download_model_result {
|
||||
// returns result with model_path and mmproj_path (empty on failure)
|
||||
common_download_model_result common_download_model(
|
||||
const common_params_model & model,
|
||||
const std::string & bearer_token,
|
||||
const common_download_model_opts & opts = {},
|
||||
const common_header_list & headers = {}
|
||||
const common_download_opts & opts = {},
|
||||
bool download_mmproj = false
|
||||
);
|
||||
|
||||
// returns list of cached models
|
||||
@@ -82,9 +98,7 @@ std::vector<common_cached_model_info> common_list_cached_models();
|
||||
// skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash)
|
||||
int common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers = {},
|
||||
const common_download_opts & opts = {},
|
||||
bool skip_etag = false);
|
||||
|
||||
// resolve and download model from Docker registry
|
||||
|
||||
@@ -52,10 +52,39 @@
|
||||
}
|
||||
},
|
||||
|
||||
{
|
||||
"name": "arm64-linux-snapdragon",
|
||||
"hidden": true,
|
||||
"architecture": { "value": "arm64", "strategy": "external" },
|
||||
"toolset": { "value": "host=x86_64", "strategy": "external" },
|
||||
"cacheVariables": {
|
||||
"CMAKE_TOOLCHAIN_FILE": "cmake/arm64-linux-clang.cmake",
|
||||
"CMAKE_C_FLAGS": "-march=armv8 -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_CXX_FLAGS": "-march=armv8 -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_C_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_PREFIX_PATH": "$env{OPENCL_SDK_ROOT}",
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"HEXAGON_TOOLS_ROOT": "$env{HEXAGON_TOOLS_ROOT}",
|
||||
"PREBUILT_LIB_DIR": "linux_aarch64",
|
||||
"GGML_OPENMP": "OFF",
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "OFF",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
|
||||
"LLAMA_OPENSSL": "OFF"
|
||||
}
|
||||
},
|
||||
|
||||
{ "name": "arm64-android-snapdragon-debug" , "inherits": [ "base", "arm64-android-snapdragon", "debug" ] },
|
||||
{ "name": "arm64-android-snapdragon-release", "inherits": [ "base", "arm64-android-snapdragon", "release" ] },
|
||||
|
||||
{ "name": "arm64-windows-snapdragon-debug" , "inherits": [ "base", "arm64-windows-snapdragon", "debug" ] },
|
||||
{ "name": "arm64-windows-snapdragon-release", "inherits": [ "base", "arm64-windows-snapdragon", "release" ] }
|
||||
{ "name": "arm64-windows-snapdragon-release", "inherits": [ "base", "arm64-windows-snapdragon", "release" ] },
|
||||
|
||||
{ "name": "arm64-linux-snapdragon-debug" , "inherits": [ "base", "arm64-linux-snapdragon", "debug" ] },
|
||||
{ "name": "arm64-linux-snapdragon-release", "inherits": [ "base", "arm64-linux-snapdragon", "release" ] }
|
||||
]
|
||||
}
|
||||
|
||||
@@ -236,10 +236,6 @@ build: 6a8cf8914 (6733)
|
||||
Controls whether the Hexagon backend allocates host buffers. By default, all buffers except for REPACK are host buffers.
|
||||
This option is required for testing Ops that require REPACK buffers (MUL_MAT and MUL_MAT_ID).
|
||||
|
||||
- `GGML_HEXAGON_EXPERIMENTAL=1`
|
||||
Controls whether the Hexagon backend enables experimental features.
|
||||
This option is required for enabling/testing experimental Ops (FLASH_ATTN_EXT).
|
||||
|
||||
- `GGML_HEXAGON_VERBOSE=1`
|
||||
Enables verbose logging of Ops from the backend. Example output:
|
||||
|
||||
@@ -259,11 +255,17 @@ build: 6a8cf8914 (6733)
|
||||
Allows enabling specific stages of the processing pipeline:
|
||||
|
||||
- `0x1` Enable Op Queue (i.e., queuing Ops into NPU)
|
||||
- `0x2` Enable Dynamic Quantizer (if needed for the Op)
|
||||
- `0x4` Enable Op Compute (MUL_MAT, etc.)
|
||||
- `0x2` Enable Op Compute (MUL_MAT, etc.)
|
||||
|
||||
Examples:
|
||||
|
||||
`GGML_HEXAGON_OPMASK=0x1 llama-completion ...` - Ops are enqueued but NPU-side processing is stubbed out
|
||||
`GGML_HEXAGON_OPMASK=0x3 llama-completion ...` - NPU performs dynamic quantization and skips the rest
|
||||
`GGML_HEXAGON_OPMASK=0x7 llama-completion ...` - Full queuing and processing of Ops (default)
|
||||
`GGML_HEXAGON_OPMASK=0x3 llama-completion ...` - Full queuing and processing of Ops (default)
|
||||
|
||||
- `GGML_HEXAGON_OPFILTER=regex`
|
||||
Allows filtering (disabling) Ops that match the regex pattern:
|
||||
|
||||
Examples:
|
||||
|
||||
`GGML_HEXAGON_OPFILTER="FLASH_ATTN_EXT" llama-completion ...` - Disable Flash Attention on Hexagon (falls back to CPU or GPU)
|
||||
`GGML_HEXAGON_OPFILTER="ADD\|SUB" llama-completion ...` - Disable ADD and SUB on Hexagon (fall back to CPU or GPU)
|
||||
|
||||
58
docs/backend/snapdragon/linux.md
Normal file
58
docs/backend/snapdragon/linux.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# Snapdragon-based Linux devices
|
||||
|
||||
## Docker Setup
|
||||
|
||||
The easiest way to build llama.cpp for a Snapdragon-based Linux device is using the toolchain Docker image (see [github.com/snapdragon-toolchain](https://github.com/snapdragon-toolchain)).
|
||||
This image includes OpenCL SDK, Hexagon SDK, CMake, and the ARM64 Linux cross-compilation toolchain.
|
||||
|
||||
Cross-compilation is supported on **Linux X86** hosts. The resulting binaries are deployed to and run on the target **Qualcomm Snapdragon ARM64 Linux** device.
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ docker run -it -u $(id -u):$(id -g) --volume $(pwd):/workspace --platform linux/amd64 ghcr.io/snapdragon-toolchain/arm64-linux:v0.1
|
||||
[d]/> cd /workspace
|
||||
```
|
||||
|
||||
Note: The rest of the **Linux** build process assumes that you're running inside the toolchain container.
|
||||
|
||||
|
||||
## How to Build
|
||||
|
||||
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
|
||||
|
||||
```
|
||||
[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json .
|
||||
|
||||
[d]/workspace> cmake --preset arm64-linux-snapdragon-release -B build-snapdragon
|
||||
|
||||
[d]/workspace> cmake --build build-snapdragon -j $(nproc)
|
||||
```
|
||||
|
||||
To generate an installable "package" simply use cmake --install, then zip it:
|
||||
|
||||
```
|
||||
[d]/workspace> cmake --install build-snapdragon --prefix pkg-snapdragon
|
||||
[d]/workspace> zip -r pkg-snapdragon.zip pkg-snapdragon
|
||||
```
|
||||
|
||||
## How to Install
|
||||
|
||||
For this step, you will deploy the built binaries and libraries to the target Linux device. Transfer `pkg-snapdragon.zip` to the target device, then unzip it and set up the environment variables:
|
||||
|
||||
```
|
||||
$ unzip pkg-snapdragon.zip
|
||||
$ cd pkg-snapdragon
|
||||
$ export LD_LIBRARY_PATH=./lib
|
||||
$ export ADSP_LIBRARY_PATH=./lib
|
||||
```
|
||||
|
||||
At this point, you should also download some models onto the device:
|
||||
|
||||
```
|
||||
$ wget https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_0.gguf
|
||||
```
|
||||
|
||||
## How to Run
|
||||
Next, since we have setup the environment variables, we can run the llama-cli with the Hexagon backends:
|
||||
```
|
||||
$ ./bin/llama-cli -m Llama-3.2-3B-Instruct-Q4_0.gguf --device HTP0 -ngl 99 -p "what is the most popular cookie in the world?"
|
||||
```
|
||||
@@ -1185,7 +1185,9 @@ struct ggml_cuda_graph {
|
||||
bool warmup_complete = false;
|
||||
struct node_properties {
|
||||
ggml_tensor node;
|
||||
void * node_src_data_ptrs[GGML_MAX_SRC];
|
||||
void * node_src_data_ptrs[GGML_MAX_SRC];
|
||||
int64_t node_src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
|
||||
size_t node_src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
|
||||
};
|
||||
std::vector<node_properties> node_props;
|
||||
|
||||
|
||||
@@ -3070,16 +3070,18 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
||||
ggml_cuda_graph::node_properties prop = {};
|
||||
memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor));
|
||||
|
||||
// if the backend scheduler is making copies of CPU tensors, the src pointers can be the same but with different data, see:
|
||||
// https://github.com/ggml-org/llama.cpp/pull/21472#discussion_r3052235188
|
||||
for (int j = 0; j < GGML_MAX_SRC; ++j) {
|
||||
prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j] ? cgraph->nodes[i]->src[j]->data : nullptr;
|
||||
if (cgraph->nodes[i]->src[j]) {
|
||||
prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j]->data;
|
||||
memcpy(prop.node_src_ne[j], cgraph->nodes[i]->src[j]->ne, sizeof(prop.node_src_ne[j]));
|
||||
memcpy(prop.node_src_nb[j], cgraph->nodes[i]->src[j]->nb, sizeof(prop.node_src_nb[j]));
|
||||
}
|
||||
}
|
||||
|
||||
if (!res && memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
|
||||
if (res || memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
|
||||
graph->node_props[i] = prop;
|
||||
res = true;
|
||||
}
|
||||
graph->node_props[i] = prop;
|
||||
}
|
||||
|
||||
return res;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,59 +14,42 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#define htp_act_preamble3 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1->ne[0]; \
|
||||
const uint32_t ne11 = src1->ne[1]; \
|
||||
const uint32_t ne12 = src1->ne[2]; \
|
||||
const uint32_t ne13 = src1->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = src1->nb[0]; \
|
||||
const uint32_t nb11 = src1->nb[1]; \
|
||||
const uint32_t nb12 = src1->nb[2]; \
|
||||
const uint32_t nb13 = src1->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
#define htp_act_preamble2 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
#define htp_act_preamble \
|
||||
const struct htp_tensor * src0 = actx->octx->src[0]; \
|
||||
const struct htp_tensor * src1 = actx->octx->src[1]; \
|
||||
const struct htp_tensor * dst = actx->octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1 ? src1->ne[0] : 0; \
|
||||
const uint32_t ne11 = src1 ? src1->ne[1] : 0; \
|
||||
const uint32_t ne12 = src1 ? src1->ne[2] : 0; \
|
||||
const uint32_t ne13 = src1 ? src1->ne[3] : 0; \
|
||||
\
|
||||
const uint32_t nb10 = src1 ? src1->nb[0] : 0; \
|
||||
const uint32_t nb11 = src1 ? src1->nb[1] : 0; \
|
||||
const uint32_t nb12 = src1 ? src1->nb[2] : 0; \
|
||||
const uint32_t nb13 = src1 ? src1->nb[3] : 0; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct htp_act_context {
|
||||
@@ -97,10 +80,7 @@ struct htp_act_context {
|
||||
|
||||
static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
htp_act_preamble;
|
||||
|
||||
size_t src0_row_size = actx->src0_row_size;
|
||||
size_t src1_row_size = actx->src1_row_size;
|
||||
@@ -207,10 +187,7 @@ static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
|
||||
static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
htp_act_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
@@ -332,9 +309,7 @@ static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, vo
|
||||
|
||||
static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble2;
|
||||
htp_act_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
@@ -433,9 +408,7 @@ static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
|
||||
static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble2;
|
||||
htp_act_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
@@ -533,10 +506,7 @@ static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
htp_act_preamble;
|
||||
|
||||
size_t src0_row_size = actx->src0_row_size;
|
||||
size_t src1_row_size = actx->src1_row_size;
|
||||
@@ -652,9 +622,9 @@ static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
}
|
||||
|
||||
static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
if (((src0->ne[0] * SIZEOF_FP32) != src0->nb[1]) || ((dst->ne[0] * SIZEOF_FP32) != dst->nb[1])) {
|
||||
FARF(ERROR, "Non-contiguous tensors are not supported at this time \n");
|
||||
@@ -697,25 +667,20 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
|
||||
size_t src0_row_size = src0->nb[1];
|
||||
size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used
|
||||
size_t src1_row_size = src1 ? src1->nb[1] : src0->nb[1];
|
||||
size_t dst_row_size = dst->nb[1];
|
||||
|
||||
const bool src1_valid = src1->ne[0];
|
||||
if (!src1_valid) {
|
||||
src1_row_size = src0_row_size;
|
||||
}
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
|
||||
size_t spad_size_per_row = (src0_row_size_aligned + src1_row_size_aligned) + dst_row_size_aligned;
|
||||
size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads* spad_size_per_row);
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if(vtcm_row_per_thread ==0){
|
||||
if (vtcm_row_per_thread == 0) {
|
||||
FARF(ERROR, "act-%s : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size_per_row * n_threads);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
@@ -733,7 +698,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
|
||||
if (src1->ne[0]) {
|
||||
octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.src = NULL;
|
||||
octx->dst_spad.src = NULL;
|
||||
|
||||
if (src1) {
|
||||
FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||
@@ -773,9 +742,9 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
|
||||
// Pointers and GLU logic
|
||||
const uint8_t * data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * data_src1 = (const uint8_t *) src1->data;
|
||||
const uint8_t * data_src1 = src1 ? (const uint8_t *) src1->data : NULL;
|
||||
|
||||
if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) {
|
||||
if (!src1 && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) {
|
||||
const int32_t swapped = octx->op_params[1];
|
||||
data_src1 = data_src0;
|
||||
actx.src1_row_size = actx.src0_row_size;
|
||||
@@ -799,7 +768,7 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
int op_activations(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
switch (octx->src[0]->type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_activations_f32(octx);
|
||||
break;
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "hex-dma.h"
|
||||
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#ifndef MIN
|
||||
@@ -175,8 +175,8 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = actx->octx;
|
||||
|
||||
// Unpack context
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
// Scratchpad memory
|
||||
uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
|
||||
@@ -249,16 +249,16 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
|
||||
|
||||
int op_argsort(struct htp_ops_context * octx) {
|
||||
// Check supported types
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
if (octx->src[0]->type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
|
||||
const uint32_t total_rows = octx->src[0]->ne[1] * octx->src[0]->ne[2] * octx->src[0]->ne[3];
|
||||
const uint32_t n_threads = MIN(total_rows, octx->n_threads);
|
||||
|
||||
// Allocate scratchpad
|
||||
// We need 1 row of float + 1 row of int32 per thread.
|
||||
uint32_t ne00 = octx->src0.ne[0];
|
||||
uint32_t ne00 = octx->src[0]->ne[0];
|
||||
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
|
||||
size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
|
||||
size_t spad_per_thread = values_size + indices_size;
|
||||
@@ -278,9 +278,9 @@ int op_argsort(struct htp_ops_context * octx) {
|
||||
octx->src0_spad.size_per_thread = spad_per_thread;
|
||||
|
||||
FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
|
||||
octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
|
||||
octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
|
||||
octx->src0.data, octx->dst.data);
|
||||
octx->src[0]->ne[0], octx->src[0]->ne[1], octx->src[0]->ne[2], octx->src[0]->ne[3],
|
||||
octx->dst->ne[0], octx->dst->ne[1], octx->dst->ne[2], octx->dst->ne[3],
|
||||
octx->src[0]->data, octx->dst->data);
|
||||
|
||||
struct htp_argsort_context actx;
|
||||
actx.octx = octx;
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#ifndef MIN
|
||||
@@ -43,10 +43,10 @@ struct htp_binary_context {
|
||||
bool split_at_ne02;
|
||||
};
|
||||
|
||||
#define htp_binary_preamble \
|
||||
const struct htp_tensor * src0 = &octx->src0; \
|
||||
const struct htp_tensor * src1 = &octx->src1; \
|
||||
struct htp_tensor * dst = &octx->dst; \
|
||||
#define htp_binary_preamble \
|
||||
const struct htp_tensor * src0 = octx->src[0]; \
|
||||
const struct htp_tensor * src1 = octx->src[1]; \
|
||||
const struct htp_tensor * dst = octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
@@ -181,7 +181,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t src0_type = octx->src[0]->type;
|
||||
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
const uint32_t start_row = bctx->nrows_per_thread * ith;
|
||||
@@ -274,7 +274,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
|
||||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t src0_type = octx->src[0]->type;
|
||||
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
const uint32_t start_row = bctx->nrows_per_thread * ith;
|
||||
@@ -374,7 +374,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
|
||||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t src0_type = octx->src[0]->type;
|
||||
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
const uint32_t start_row = bctx->nrows_per_thread * ith;
|
||||
@@ -455,7 +455,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
|
||||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t src0_type = octx->src[0]->type;
|
||||
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
const uint32_t start_row = bctx->nrows_per_thread * ith;
|
||||
@@ -540,7 +540,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
|
||||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t src0_type = octx->src[0]->type;
|
||||
const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
|
||||
const uint32_t row_size_bytes = ne00 * elem_size_bytes;;
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
@@ -629,10 +629,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_binary_context * bctx = (struct htp_binary_context *) data;
|
||||
struct htp_ops_context * octx = bctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * src2 = octx->src[2];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const uint32_t ne00 = src0->ne[0];
|
||||
const uint32_t ne01 = src0->ne[1];
|
||||
@@ -723,15 +723,15 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
}
|
||||
|
||||
static int execute_op_binary(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
|
||||
// Use packed row sizes for VTCM allocation
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t src0_type = octx->src[0]->type;
|
||||
const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
|
||||
const size_t src0_row_size = src0->ne[0] * elem_size;
|
||||
const size_t src1_row_size = src1->ne[0] * elem_size;
|
||||
@@ -799,9 +799,9 @@ static int execute_op_binary(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL;
|
||||
|
||||
if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
return HTP_STATUS_OK;
|
||||
@@ -857,12 +857,12 @@ static int execute_op_binary(struct htp_ops_context * octx) {
|
||||
int op_binary(struct htp_ops_context * octx) {
|
||||
|
||||
// Does not support permutations of src1
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
if (src1->nb[1] < src1->nb[0]) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t src0_type = octx->src[0]->type;
|
||||
if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) {
|
||||
return execute_op_binary(octx);
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
@@ -32,10 +32,10 @@ struct htp_copy_context {
|
||||
void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith);
|
||||
};
|
||||
|
||||
#define cpy_preamble \
|
||||
struct htp_tensor *src0 = &octx->src0; \
|
||||
struct htp_tensor *dst = &octx->dst; \
|
||||
\
|
||||
#define cpy_preamble \
|
||||
const struct htp_tensor *src0 = octx->src[0]; \
|
||||
const struct htp_tensor *dst = octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
|
||||
@@ -13,9 +13,9 @@
|
||||
#include "hvx-utils.h"
|
||||
#include "hex-dma.h"
|
||||
|
||||
#define htp_cumsum_tensors_preamble \
|
||||
struct htp_tensor * restrict src0 = &octx->src0; \
|
||||
struct htp_tensor * restrict dst = &octx->dst; \
|
||||
#define htp_cumsum_tensors_preamble \
|
||||
const struct htp_tensor * restrict src0 = octx->src[0]; \
|
||||
const struct htp_tensor * restrict dst = octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
@@ -206,8 +206,8 @@ static void cumsum_thread_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
}
|
||||
|
||||
int op_cumsum_f32(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
||||
return HTP_STATUS_OK;
|
||||
@@ -226,10 +226,12 @@ int op_cumsum_f32(struct htp_ops_context * octx) {
|
||||
|
||||
octx->src0_spad.size_per_thread = src_row_size_aligned * 2;
|
||||
octx->dst_spad.size_per_thread = dst_row_size_aligned * 2;
|
||||
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
|
||||
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
|
||||
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL;
|
||||
|
||||
struct htp_cumsum_context cctx = {
|
||||
.octx = octx,
|
||||
@@ -251,8 +253,9 @@ int op_cumsum_f32(struct htp_ops_context * octx) {
|
||||
}
|
||||
|
||||
int op_cumsum(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (dst->type) {
|
||||
case HTP_TYPE_F32:
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
// Must be multiple of 32
|
||||
@@ -278,12 +278,12 @@ static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t *
|
||||
static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_fa_context * factx = (struct htp_fa_context *) data;
|
||||
const struct htp_ops_context * octx = factx->octx;
|
||||
const struct htp_tensor * q = &octx->src0;
|
||||
const struct htp_tensor * k = &octx->src1;
|
||||
const struct htp_tensor * v = &octx->src2;
|
||||
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
||||
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * q = octx->src[0];
|
||||
const struct htp_tensor * k = octx->src[1];
|
||||
const struct htp_tensor * v = octx->src[2];
|
||||
const struct htp_tensor * mask = octx->src[3];
|
||||
const struct htp_tensor * sinks = octx->src[4];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const uint32_t neq0 = q->ne[0];
|
||||
const uint32_t neq1 = q->ne[1];
|
||||
@@ -610,11 +610,11 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
||||
}
|
||||
|
||||
int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * q = &octx->src0;
|
||||
const struct htp_tensor * k = &octx->src1;
|
||||
const struct htp_tensor * v = &octx->src2;
|
||||
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * q = octx->src[0];
|
||||
const struct htp_tensor * k = octx->src[1];
|
||||
const struct htp_tensor * v = octx->src[2];
|
||||
const struct htp_tensor * mask = octx->src[3];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
// Check support
|
||||
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {
|
||||
@@ -701,13 +701,11 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
|
||||
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
|
||||
|
||||
// FARF(ERROR, "fa: qrows-per-thread %u", factx.qrows_per_thread);
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL;
|
||||
octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->src2_spad.src = NULL;
|
||||
octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; octx->src3_spad.src = NULL;
|
||||
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; octx->dst_spad.src = NULL;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
@@ -23,27 +23,33 @@ struct get_rows_context {
|
||||
};
|
||||
|
||||
#define get_rows_preamble \
|
||||
const uint32_t ne00 = octx->src0.ne[0]; \
|
||||
const uint32_t ne01 = octx->src0.ne[1]; \
|
||||
const uint32_t ne02 = octx->src0.ne[2]; \
|
||||
const uint32_t ne03 = octx->src0.ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = octx->src1.ne[0]; \
|
||||
const uint32_t ne11 = octx->src1.ne[1]; \
|
||||
const uint32_t ne12 = octx->src1.ne[2]; \
|
||||
\
|
||||
const uint32_t nb01 = octx->src0.nb[1]; \
|
||||
const uint32_t nb02 = octx->src0.nb[2]; \
|
||||
const uint32_t nb03 = octx->src0.nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = octx->src1.nb[0]; \
|
||||
const uint32_t nb11 = octx->src1.nb[1]; \
|
||||
const uint32_t nb12 = octx->src1.nb[2]; \
|
||||
\
|
||||
const uint32_t nb1 = octx->dst.nb[1]; \
|
||||
const uint32_t nb2 = octx->dst.nb[2]; \
|
||||
const uint32_t nb3 = octx->dst.nb[3]; \
|
||||
\
|
||||
const uint32_t ne00 = octx->src[0]->ne[0]; \
|
||||
const uint32_t ne01 = octx->src[0]->ne[1]; \
|
||||
const uint32_t ne02 = octx->src[0]->ne[2]; \
|
||||
const uint32_t ne03 = octx->src[0]->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = octx->src[1]->ne[0]; \
|
||||
const uint32_t ne11 = octx->src[1]->ne[1]; \
|
||||
const uint32_t ne12 = octx->src[1]->ne[2]; \
|
||||
const uint32_t ne13 = octx->src[1]->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = octx->dst->ne[0]; \
|
||||
const uint32_t ne1 = octx->dst->ne[1]; \
|
||||
const uint32_t ne2 = octx->dst->ne[2]; \
|
||||
const uint32_t ne3 = octx->dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb01 = octx->src[0]->nb[1]; \
|
||||
const uint32_t nb02 = octx->src[0]->nb[2]; \
|
||||
const uint32_t nb03 = octx->src[0]->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = octx->src[1]->nb[0]; \
|
||||
const uint32_t nb11 = octx->src[1]->nb[1]; \
|
||||
const uint32_t nb12 = octx->src[1]->nb[2]; \
|
||||
\
|
||||
const uint32_t nb1 = octx->dst->nb[1]; \
|
||||
const uint32_t nb2 = octx->dst->nb[2]; \
|
||||
const uint32_t nb3 = octx->dst->nb[3]; \
|
||||
\
|
||||
const uint32_t nr = ne10 * ne11 * ne12;
|
||||
|
||||
static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
@@ -51,12 +57,14 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
||||
struct htp_ops_context * octx = grctx->octx;
|
||||
get_rows_preamble;
|
||||
|
||||
uint64_t qt = HAP_perf_get_qtimer_count();
|
||||
|
||||
// parallelize by src1 elements (which correspond to dst rows)
|
||||
const uint32_t dr = grctx->src1_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
|
||||
const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32);
|
||||
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11);
|
||||
@@ -64,7 +72,7 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
||||
const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10);
|
||||
const uint32_t i10 = rem - i11 * ne10;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
||||
uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
|
||||
|
||||
@@ -73,10 +81,14 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
||||
continue;
|
||||
}
|
||||
|
||||
const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
|
||||
const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3;
|
||||
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
}
|
||||
|
||||
qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt);
|
||||
FARF(HIGH, "get-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt);
|
||||
}
|
||||
|
||||
int op_get_rows(struct htp_ops_context * octx) {
|
||||
@@ -84,15 +96,15 @@ int op_get_rows(struct htp_ops_context * octx) {
|
||||
|
||||
const uint32_t n_threads = MIN(nr, octx->n_threads);
|
||||
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
if (octx->src[0]->type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->dst.type != HTP_TYPE_F32) {
|
||||
if (octx->dst->type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
|
||||
if (octx->src[1]->type != HTP_TYPE_I32 && octx->src[1]->type != HTP_TYPE_I64) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
@@ -102,8 +114,8 @@ int op_get_rows(struct htp_ops_context * octx) {
|
||||
|
||||
struct get_rows_context grctx;
|
||||
grctx.octx = octx;
|
||||
grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
|
||||
grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
|
||||
grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src[1]->ne[0]);
|
||||
grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src[1]->ne[0] * octx->src[1]->ne[1]);
|
||||
|
||||
grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads;
|
||||
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <qurt_memory.h>
|
||||
|
||||
#include "hexagon_types.h"
|
||||
#include "hexagon_protos.h"
|
||||
|
||||
#include "hex-fastdiv.h"
|
||||
#include "hex-dump.h"
|
||||
@@ -68,4 +70,23 @@ static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride,
|
||||
Q6_l2fetch_AP((void *) p, control);
|
||||
}
|
||||
|
||||
#define HEX_L2_LINE_SIZE 64
|
||||
#define HEX_L2_FLUSH_SIZE (128 * 1024)
|
||||
|
||||
static inline void hex_l2flush(void * addr, size_t size)
|
||||
{
|
||||
if (size > HEX_L2_FLUSH_SIZE) {
|
||||
qurt_mem_cache_clean((qurt_addr_t) 0, 0, QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE);
|
||||
} else {
|
||||
const uint32_t s = (uint32_t) addr;
|
||||
const uint32_t e = s + size;
|
||||
for (uint32_t i = s; i < e; i += HEX_L2_LINE_SIZE * 4) {
|
||||
Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 0);
|
||||
Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 1);
|
||||
Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 2);
|
||||
Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 3);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* HEX_UTILS_H */
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
#include "hvx-dump.h"
|
||||
#include "worker-pool.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#include "hmx-utils.h"
|
||||
#include "hmx-ops.h"
|
||||
@@ -821,7 +821,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
|
||||
// and each q_head is computed individually to avoid tile-major packing
|
||||
// issues. m_chunk_n_rows is always a multiple of 32 (from
|
||||
// hmx_compute_chunks), so per-head tile arrays don't overlap.
|
||||
const size_t vtcm_budget = ctx->vtcm_scratch_size;
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
const size_t vec_dot_size = params->k * sizeof(__fp16);
|
||||
|
||||
// When the activation has a large stride (e.g. permuted Q tensor with
|
||||
@@ -998,7 +998,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
|
||||
}
|
||||
|
||||
// --- Dynamic VTCM layout ---
|
||||
const size_t vtcm_budget = ctx->vtcm_scratch_size;
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
const size_t vec_dot_size = k * sizeof(__fp16);
|
||||
|
||||
// DMA-based activation gather for strided tensors (see batched path comment).
|
||||
@@ -1182,7 +1182,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
||||
FARF(MEDIUM, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type);
|
||||
|
||||
// --- Dynamic VTCM layout ---
|
||||
const size_t vtcm_budget = ctx->vtcm_scratch_size;
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
const size_t vec_dot_size = k * sizeof(__fp16);
|
||||
const bool use_pipeline = (m >= 128) && (k <= n);
|
||||
|
||||
@@ -1273,9 +1273,6 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
||||
void *buf_curr = vtcm_scratch0;
|
||||
void *buf_next = vtcm_scratch1;
|
||||
|
||||
// issue async DDR data transfer for the first weight chunk
|
||||
// NOTE: use 2D DMA (n_cols rows x row_stride bytes) instead of 1D
|
||||
// because UDMA roiwidth is 16-bit and total size can exceed 65535.
|
||||
{
|
||||
const size_t n_cols_first = hex_smin(n, n_chunk_n_cols);
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first);
|
||||
@@ -1533,20 +1530,15 @@ void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, co
|
||||
worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads);
|
||||
}
|
||||
|
||||
int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m,
|
||||
int k, int n, int weight_type) {
|
||||
// Runtime check -- k >= 16384 exceeds 2D DMA limit
|
||||
if (k >= 16384) {
|
||||
FARF(HIGH, "%s: k=%d exceeds 2D DMA limit", __func__, k);
|
||||
return -1;
|
||||
}
|
||||
int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w,
|
||||
int m, int k, int n, int weight_type) {
|
||||
// assume k % 32 == 0 && n % 32 == 0
|
||||
const size_t row_stride = get_x4x2_row_stride(weight_type, k);
|
||||
if (row_stride == 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
const size_t vtcm_budget = ctx->vtcm_scratch_size;
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
|
||||
const size_t M_BLOCK_SIZE = 512;
|
||||
const size_t N_BLOCK_SIZE = 512;
|
||||
@@ -1576,8 +1568,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
|
||||
__fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
|
||||
assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget);
|
||||
|
||||
FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu",
|
||||
__func__, m, k, n, weight_type,
|
||||
FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu", __func__, m, k, n, weight_type,
|
||||
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
|
||||
|
||||
// initialize eye tile (32x32 identity matrix)
|
||||
|
||||
@@ -7,16 +7,12 @@
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#ifndef restrict
|
||||
# define restrict __restrict
|
||||
#endif
|
||||
#include "htp-ops.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct htp_context; // forward declaration
|
||||
|
||||
typedef struct {
|
||||
float *dst;
|
||||
const float *activation;
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define HTP_CTX_H
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "htp-ops.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
#include <assert.h>
|
||||
@@ -10,38 +11,85 @@
|
||||
#include <stdint.h>
|
||||
|
||||
#define HTP_MAX_NTHREADS 10
|
||||
#define HTP_MAX_MMAPS 16
|
||||
|
||||
// Memory mapping
|
||||
struct htp_mmap {
|
||||
uint64_t size;
|
||||
uint64_t base;
|
||||
uint32_t fd;
|
||||
uint32_t pinned;
|
||||
};
|
||||
|
||||
// Scratchpad state
|
||||
struct htp_spad {
|
||||
const struct htp_tensor * src; // original src of the data (for reuse)
|
||||
uint8_t * data; // pointer to an area in vtcm
|
||||
uint32_t stride; // stride used inside this spad
|
||||
uint32_t size; // total size
|
||||
uint32_t size_per_thread; // size per thread
|
||||
};
|
||||
|
||||
// Context while processing an Op
|
||||
// TODO: fold this into the main context
|
||||
struct htp_ops_context {
|
||||
struct htp_context * ctx;
|
||||
|
||||
enum htp_op_code op; // FIXME: rename to opcode
|
||||
int32_t op_params[HTP_OP_MAX_PARAMS];
|
||||
|
||||
const struct htp_tensor * src[HTP_OP_MAX_INPUTS];
|
||||
const struct htp_tensor * dst;
|
||||
|
||||
// TODO convert these to an array
|
||||
struct htp_spad src0_spad;
|
||||
struct htp_spad src1_spad;
|
||||
struct htp_spad src2_spad;
|
||||
struct htp_spad src3_spad;
|
||||
struct htp_spad dst_spad;
|
||||
|
||||
uint32_t n_threads;
|
||||
uint32_t flags;
|
||||
};
|
||||
|
||||
// Main context for htp DSP backend
|
||||
struct htp_context {
|
||||
dspqueue_t queue;
|
||||
dma_queue * dma[HTP_MAX_NTHREADS];
|
||||
worker_pool_context_t worker_pool;
|
||||
uint32_t n_threads;
|
||||
dspqueue_t queue;
|
||||
dma_queue * dma[HTP_MAX_NTHREADS];
|
||||
struct htp_mmap mmap[HTP_MAX_MMAPS];
|
||||
worker_pool_context_t worker_pool;
|
||||
uint32_t n_threads;
|
||||
|
||||
int thread_id;
|
||||
int thread_prio;
|
||||
int thread_id;
|
||||
int thread_prio;
|
||||
|
||||
uint8_t * vtcm_base;
|
||||
size_t vtcm_size;
|
||||
uint32_t vtcm_rctx;
|
||||
int hmx_enabled;
|
||||
|
||||
atomic_bool vtcm_valid;
|
||||
atomic_bool vtcm_inuse;
|
||||
atomic_bool vtcm_needs_release;
|
||||
uint8_t * vtcm_base;
|
||||
size_t vtcm_size;
|
||||
uint32_t vtcm_rctx;
|
||||
atomic_bool vtcm_valid;
|
||||
atomic_bool vtcm_needs_release;
|
||||
|
||||
uint32_t opmask;
|
||||
|
||||
// Cached src1 spad position from the last quantize pass.
|
||||
// When SKIP_QUANTIZE is set the Q8 activation data is already in VTCM
|
||||
// at this address; the matmul must read from here instead of recomputing
|
||||
// the offset (which depends on the current op's src0 size).
|
||||
uint8_t * prev_src1_spad;
|
||||
|
||||
// HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX)
|
||||
#ifdef HTP_HAS_HMX
|
||||
int hmx_enabled; // Runtime flag: HMX initialisation succeeded
|
||||
size_t vtcm_scratch_size; // Usable dynamic scratch (vtcm_size minus tail reservation)
|
||||
#endif
|
||||
struct htp_ops_context octx;
|
||||
};
|
||||
|
||||
int op_matmul(struct htp_ops_context * octx);
|
||||
int op_matmul_id(struct htp_ops_context * octx);
|
||||
int op_binary(struct htp_ops_context * octx);
|
||||
int op_unary(struct htp_ops_context * octx);
|
||||
int op_sum_rows(struct htp_ops_context * octx);
|
||||
int op_activations(struct htp_ops_context * octx);
|
||||
int op_softmax(struct htp_ops_context * octx);
|
||||
int op_add_id(struct htp_ops_context * octx);
|
||||
int op_rope(struct htp_ops_context * octx);
|
||||
int op_flash_attn_ext(struct htp_ops_context * octx);
|
||||
int op_set_rows(struct htp_ops_context * octx);
|
||||
int op_get_rows(struct htp_ops_context * octx);
|
||||
int op_cpy(struct htp_ops_context * octx);
|
||||
int op_repeat(struct htp_ops_context * octx);
|
||||
int op_argsort(struct htp_ops_context * octx);
|
||||
int op_ssm_conv(struct htp_ops_context * octx);
|
||||
int op_cumsum(struct htp_ops_context * octx);
|
||||
|
||||
#endif /* HTP_CTX_H */
|
||||
|
||||
@@ -1,166 +0,0 @@
|
||||
#ifndef HTP_MSG_H
|
||||
#define HTP_MSG_H
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
// ggml-common.h must be included prio to this header
|
||||
|
||||
// Mask to enable various stages of the Ops.
|
||||
// Used for debugging and profiling.
|
||||
enum {
|
||||
HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP)
|
||||
HTP_OPMASK_QUANTIZE = (1 << 1), // Enable Quantize
|
||||
HTP_OPMASK_COMPUTE = (1 << 2), // Enable Compute
|
||||
};
|
||||
|
||||
// Op flags
|
||||
enum {
|
||||
HTP_OPFLAGS_SKIP_QUANTIZE = (1 << 0), // Skip dynamic quantization (reuse quantized tensors)
|
||||
HTP_OPFLAGS_SKIP_COMPUTE = (1 << 1), // Skip actual computation (used for profiling)
|
||||
HTP_OPFLAGS_EARLY_WAKEUP = (1 << 2) // Send early wakeup notification
|
||||
};
|
||||
|
||||
enum htp_status {
|
||||
HTP_STATUS_OK = 1,
|
||||
HTP_STATUS_INTERNAL_ERR = 2,
|
||||
HTP_STATUS_NO_SUPPORT = 3,
|
||||
HTP_STATUS_INVAL_PARAMS = 4,
|
||||
HTP_STATUS_VTCM_TOO_SMALL = 5,
|
||||
};
|
||||
|
||||
// The values must match the ggml_type.
|
||||
// Duplicated here because we can't include full ggml.h in the htp build.
|
||||
// We have some static_asserts in the cpp code to ensure things are in sync.
|
||||
enum htp_data_type {
|
||||
HTP_TYPE_F32 = 0,
|
||||
HTP_TYPE_F16 = 1,
|
||||
HTP_TYPE_Q4_0 = 2,
|
||||
HTP_TYPE_Q8_0 = 8,
|
||||
HTP_TYPE_IQ4_NL = 20,
|
||||
HTP_TYPE_I32 = 26,
|
||||
HTP_TYPE_I64 = 27,
|
||||
HTP_TYPE_MXFP4 = 39,
|
||||
HTP_TYPE_COUNT
|
||||
};
|
||||
|
||||
// Do not reorder first 4 (used as an index)
|
||||
enum htp_op {
|
||||
HTP_OP_MUL = 0,
|
||||
HTP_OP_ADD = 1,
|
||||
HTP_OP_SUB = 2,
|
||||
HTP_OP_DIV = 3,
|
||||
HTP_OP_MUL_MAT,
|
||||
HTP_OP_MUL_MAT_ID,
|
||||
HTP_OP_RMS_NORM,
|
||||
HTP_OP_UNARY_SILU,
|
||||
HTP_OP_UNARY_GELU,
|
||||
HTP_OP_UNARY_SIGMOID,
|
||||
HTP_OP_UNARY_EXP,
|
||||
HTP_OP_UNARY_NEG,
|
||||
HTP_OP_UNARY_SOFTPLUS,
|
||||
HTP_OP_GLU_SWIGLU,
|
||||
HTP_OP_GLU_SWIGLU_OAI,
|
||||
HTP_OP_GLU_GEGLU,
|
||||
HTP_OP_SOFTMAX,
|
||||
HTP_OP_ADD_ID,
|
||||
HTP_OP_ROPE,
|
||||
HTP_OP_FLASH_ATTN_EXT,
|
||||
HTP_OP_SET_ROWS,
|
||||
HTP_OP_GET_ROWS,
|
||||
HTP_OP_SCALE,
|
||||
HTP_OP_CPY,
|
||||
HTP_OP_ARGSORT,
|
||||
HTP_OP_SQR,
|
||||
HTP_OP_SQRT,
|
||||
HTP_OP_SUM_ROWS,
|
||||
HTP_OP_SSM_CONV,
|
||||
HTP_OP_REPEAT,
|
||||
HTP_OP_CUMSUM,
|
||||
INVALID
|
||||
};
|
||||
|
||||
static inline size_t htp_t_block_size(uint32_t t) {
|
||||
switch (t) {
|
||||
case HTP_TYPE_F32:
|
||||
return 1;
|
||||
case HTP_TYPE_F16:
|
||||
return 1;
|
||||
case HTP_TYPE_Q4_0:
|
||||
return QK4_0;
|
||||
case HTP_TYPE_Q8_0:
|
||||
return QK8_0;
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return QK4_NL;
|
||||
case HTP_TYPE_MXFP4:
|
||||
return QK_MXFP4;
|
||||
default:
|
||||
assert(0 && "unsupported HTP data type");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static inline size_t htp_type_nbytes(uint32_t t) {
|
||||
switch (t) {
|
||||
case HTP_TYPE_F32:
|
||||
return 4;
|
||||
case HTP_TYPE_F16:
|
||||
return 2;
|
||||
case HTP_TYPE_Q4_0:
|
||||
return sizeof(block_q4_0);
|
||||
case HTP_TYPE_Q8_0:
|
||||
return sizeof(block_q8_0);
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return sizeof(block_iq4_nl);
|
||||
case HTP_TYPE_MXFP4:
|
||||
return sizeof(block_mxfp4);
|
||||
default:
|
||||
assert(0 && "unsupported HTP data type");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Internal types
|
||||
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
|
||||
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
|
||||
#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
|
||||
|
||||
#define HTP_MAX_DIMS 4
|
||||
|
||||
struct htp_tensor {
|
||||
uint32_t data; // Buffer offset in the messages, and data pointer on the NSP
|
||||
uint32_t type; // Data type
|
||||
uint32_t ne[HTP_MAX_DIMS]; // Number of elements
|
||||
uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor)
|
||||
};
|
||||
|
||||
#define HTP_MAX_OP_PARAMS 64
|
||||
|
||||
struct htp_general_req {
|
||||
uint32_t op; // GGML/HTP Op
|
||||
int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
// Params for the op, e.g. epsilon of RMS norm
|
||||
uint32_t flags; // Request flags
|
||||
|
||||
struct htp_tensor src0; // Input0 tensor
|
||||
struct htp_tensor src1; // Input1 tensor
|
||||
struct htp_tensor src2; // Input2 tensor
|
||||
struct htp_tensor src3; // Input3 tensor
|
||||
struct htp_tensor src4; // Input4 tensor
|
||||
struct htp_tensor dst; // Output tensor
|
||||
|
||||
// should be multiple of 64 bytes (cacheline)
|
||||
};
|
||||
|
||||
struct htp_general_rsp {
|
||||
uint32_t op; // GGML/HTP Op
|
||||
uint32_t status; // HTP_STATUS_...
|
||||
uint32_t prof_usecs; // Number of usec per request
|
||||
uint32_t prof_cycles; // Number of cycles per request
|
||||
uint32_t prof_pkts; // Number of instruction packets per request
|
||||
uint8_t unused[44]; // Pad to 64 bytes
|
||||
};
|
||||
|
||||
#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req)
|
||||
#define HTP_MAX_PACKET_BUFFERS 8
|
||||
|
||||
#endif /* HTP_MSG_H */
|
||||
@@ -1,65 +1,154 @@
|
||||
#ifndef HTP_OPS_H
|
||||
#define HTP_OPS_H
|
||||
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <hex-fastdiv.h>
|
||||
// ggml-common.h must be included prio to this header
|
||||
|
||||
// ggml-common.h must be included prior to this header
|
||||
|
||||
struct htp_spad {
|
||||
uint8_t * data;
|
||||
size_t stride;
|
||||
size_t size;
|
||||
size_t size_per_thread;
|
||||
enum htp_status {
|
||||
HTP_STATUS_OK = 1,
|
||||
HTP_STATUS_INTERNAL_ERR = 2,
|
||||
HTP_STATUS_NO_SUPPORT = 3,
|
||||
HTP_STATUS_INVAL_PARAMS = 4,
|
||||
HTP_STATUS_VTCM_TOO_SMALL = 5,
|
||||
};
|
||||
|
||||
struct htp_ops_context {
|
||||
struct htp_context * ctx;
|
||||
// First set of values must match the ggml_type.
|
||||
// Duplicated here because we can't include full ggml.h in the htp build.
|
||||
// We have some static_asserts in the cpp code to ensure things are in sync.
|
||||
enum htp_data_type {
|
||||
HTP_TYPE_F32 = 0,
|
||||
HTP_TYPE_F16 = 1,
|
||||
HTP_TYPE_Q4_0 = 2,
|
||||
HTP_TYPE_Q8_0 = 8,
|
||||
HTP_TYPE_IQ4_NL = 20,
|
||||
HTP_TYPE_I32 = 26,
|
||||
HTP_TYPE_I64 = 27,
|
||||
HTP_TYPE_MXFP4 = 39,
|
||||
|
||||
enum htp_op op;
|
||||
int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
// types used internally for repack, dyn.quant, etc
|
||||
HTP_TYPE_Q4_0x4x2 = 200,
|
||||
HTP_TYPE_Q8_0x4x2,
|
||||
HTP_TYPE_MXFP4x4x2,
|
||||
|
||||
struct htp_tensor src0;
|
||||
struct htp_tensor src1;
|
||||
struct htp_tensor src2;
|
||||
struct htp_tensor src3;
|
||||
struct htp_tensor src4;
|
||||
struct htp_tensor dst;
|
||||
|
||||
struct htp_spad src0_spad;
|
||||
struct htp_spad src1_spad;
|
||||
struct htp_spad src2_spad;
|
||||
struct htp_spad src3_spad;
|
||||
struct htp_spad dst_spad;
|
||||
|
||||
worker_pool_context_t * wpool; // worker pool
|
||||
uint32_t n_threads; // num threads
|
||||
|
||||
uint32_t flags;
|
||||
HTP_TYPE_INVALID
|
||||
};
|
||||
|
||||
int op_matmul(struct htp_ops_context * octx);
|
||||
int op_matmul_id(struct htp_ops_context * octx);
|
||||
int op_binary(struct htp_ops_context * octx);
|
||||
int op_unary(struct htp_ops_context * octx);
|
||||
int op_sum_rows(struct htp_ops_context * octx);
|
||||
int op_activations(struct htp_ops_context * octx);
|
||||
int op_softmax(struct htp_ops_context * octx);
|
||||
int op_add_id(struct htp_ops_context * octx);
|
||||
int op_rope(struct htp_ops_context * octx);
|
||||
int op_flash_attn_ext(struct htp_ops_context * octx);
|
||||
int op_set_rows(struct htp_ops_context * octx);
|
||||
int op_get_rows(struct htp_ops_context * octx);
|
||||
int op_cpy(struct htp_ops_context * octx);
|
||||
int op_repeat(struct htp_ops_context * octx);
|
||||
int op_argsort(struct htp_ops_context * octx);
|
||||
int op_ssm_conv(struct htp_ops_context * octx);
|
||||
int op_cumsum(struct htp_ops_context * octx);
|
||||
// Constats for internal types
|
||||
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
|
||||
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
|
||||
#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
|
||||
|
||||
|
||||
// Mask to enable various stages of the Ops.
|
||||
// Used for debugging and profiling.
|
||||
enum htp_op_mask {
|
||||
HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP)
|
||||
HTP_OPMASK_COMPUTE = (1 << 1), // Enable Compute
|
||||
};
|
||||
|
||||
// Do not reorder first 4 (used as an index)
|
||||
enum htp_op_code {
|
||||
HTP_OP_MUL = 0,
|
||||
HTP_OP_ADD = 1,
|
||||
HTP_OP_SUB = 2,
|
||||
HTP_OP_DIV = 3,
|
||||
HTP_OP_MUL_MAT,
|
||||
HTP_OP_MUL_MAT_ID,
|
||||
HTP_OP_RMS_NORM,
|
||||
HTP_OP_UNARY_SILU,
|
||||
HTP_OP_UNARY_GELU,
|
||||
HTP_OP_UNARY_SIGMOID,
|
||||
HTP_OP_UNARY_EXP,
|
||||
HTP_OP_UNARY_NEG,
|
||||
HTP_OP_UNARY_SOFTPLUS,
|
||||
HTP_OP_GLU_SWIGLU,
|
||||
HTP_OP_GLU_SWIGLU_OAI,
|
||||
HTP_OP_GLU_GEGLU,
|
||||
HTP_OP_SOFTMAX,
|
||||
HTP_OP_ADD_ID,
|
||||
HTP_OP_ROPE,
|
||||
HTP_OP_FLASH_ATTN_EXT,
|
||||
HTP_OP_SET_ROWS,
|
||||
HTP_OP_GET_ROWS,
|
||||
HTP_OP_SCALE,
|
||||
HTP_OP_CPY,
|
||||
HTP_OP_ARGSORT,
|
||||
HTP_OP_SQR,
|
||||
HTP_OP_SQRT,
|
||||
HTP_OP_SUM_ROWS,
|
||||
HTP_OP_SSM_CONV,
|
||||
HTP_OP_REPEAT,
|
||||
HTP_OP_CUMSUM,
|
||||
|
||||
HTP_OP_INVALID
|
||||
};
|
||||
|
||||
#define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS
|
||||
#define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS
|
||||
#define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS
|
||||
|
||||
#define HTP_OP_MAX_BUFS 8
|
||||
#define HTP_OP_MAX_REQS 256
|
||||
#define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS)
|
||||
#define HTP_OP_MAX_VMEM (3221225472u)
|
||||
|
||||
enum htp_tensor_flags {
|
||||
HTP_TENSOR_COMPUTE = (1U << 0), // Tensor buffer temporal compute data (not weights)
|
||||
HTP_TENSOR_FLUSHED = (1U << 1) // Tensor buffer has been flushed (set by the NPU)
|
||||
};
|
||||
|
||||
// Tensor descriptor
|
||||
struct htp_tensor {
|
||||
uint32_t data; // Buffer offset in the messages, and data pointer on the NPU
|
||||
uint32_t size; // Data size in bytes
|
||||
uint32_t flags; // Buffer / tensor flags
|
||||
uint16_t type; // Data type
|
||||
uint16_t bi; // Buffer index
|
||||
uint32_t ne[HTP_OP_MAX_DIMS]; // Number of elements
|
||||
uint32_t nb[HTP_OP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor)
|
||||
};
|
||||
|
||||
// Buffer descriptor
|
||||
struct htp_buf_desc {
|
||||
uint64_t base; // base address
|
||||
uint64_t size; // total size
|
||||
uint32_t flags; // buffer flags (unused)
|
||||
uint32_t fd; // file descriptor
|
||||
};
|
||||
|
||||
enum htp_op_flags {
|
||||
HTP_OPFLAGS_SKIP_COMPUTE = (1U << 0), // Skip actual computation (used for profiling)
|
||||
};
|
||||
|
||||
// Op descriptor
|
||||
struct htp_op_desc {
|
||||
uint32_t opcode; // GGML/HTP Op
|
||||
uint32_t flags; // Op flags
|
||||
int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm
|
||||
uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices
|
||||
uint16_t dst; // Output tensor index
|
||||
|
||||
// the rest is filled in-place by the NPU
|
||||
uint32_t prof_usecs; // Number of usec per request
|
||||
uint32_t prof_cycles; // Number of cycles per request
|
||||
uint32_t prof_pkts; // Number of instruction packets per request
|
||||
uint32_t unused;
|
||||
};
|
||||
|
||||
struct htp_opbatch_req {
|
||||
uint32_t n_bufs; // Number of buffers
|
||||
uint32_t n_tensors; // Number of tensors
|
||||
uint32_t n_ops; // Number of ops
|
||||
uint32_t flags; // unused
|
||||
// struct htp_buf_desc bufs[]; -- dspqueue buf 0
|
||||
// struct htp_tensor tensors[]; -- dspqueue buf 0
|
||||
// struct htp_op_desc ops[]; -- dspqueue buf 0
|
||||
};
|
||||
|
||||
struct htp_opbatch_rsp {
|
||||
uint32_t status; // HTP_STATUS_...
|
||||
// struct htp_op_req ops[]; -- dspqueue buf 0
|
||||
};
|
||||
|
||||
#endif /* HTP_OPS_H */
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
interface htp_iface : remote_handle64 {
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx);
|
||||
AEEResult stop();
|
||||
AEEResult mmap(in uint32 fd, in uint32 size, in uint32 pinned);
|
||||
AEEResult munmap(in uint32 fd);
|
||||
AEEResult enable_etm();
|
||||
AEEResult disable_etm();
|
||||
};
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -16,8 +16,9 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hmx-ops.h"
|
||||
|
||||
#define MM_SPAD_SRC0_NROWS 16
|
||||
#define MM_SPAD_SRC1_NROWS 16
|
||||
@@ -1897,11 +1898,11 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void *
|
||||
hvx_vec_store_u(&s[0], 4, rsum);
|
||||
}
|
||||
|
||||
#define htp_matmul_tensors_preamble \
|
||||
struct htp_tensor * restrict src0 = &octx->src0; \
|
||||
struct htp_tensor * restrict src1 = &octx->src1; \
|
||||
struct htp_tensor * restrict src2 = &octx->src2; \
|
||||
struct htp_tensor * restrict dst = &octx->dst; \
|
||||
#define htp_matmul_tensors_preamble \
|
||||
const struct htp_tensor * restrict src0 = octx->src[0]; \
|
||||
const struct htp_tensor * restrict src1 = octx->src[1]; \
|
||||
const struct htp_tensor * restrict src2 = octx->src[2]; \
|
||||
const struct htp_tensor * restrict dst = octx->dst; \
|
||||
struct htp_spad * restrict src0_spad = &octx->src0_spad; \
|
||||
struct htp_spad * restrict src1_spad = &octx->src1_spad; \
|
||||
struct htp_spad * restrict dst_spad = &octx->dst_spad; \
|
||||
@@ -2223,8 +2224,8 @@ struct mmid_row_mapping {
|
||||
static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_matmul_preamble;
|
||||
|
||||
struct htp_tensor * restrict ids = &octx->src2;
|
||||
struct htp_spad * restrict src2_spad = &octx->src2_spad;
|
||||
const struct htp_tensor * restrict ids = octx->src[2];
|
||||
struct htp_spad * restrict src2_spad = &octx->src2_spad;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
@@ -2342,8 +2343,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_matmul_preamble;
|
||||
|
||||
struct htp_tensor * restrict ids = &octx->src2;
|
||||
struct htp_spad * restrict src2_spad = &octx->src2_spad;
|
||||
const struct htp_tensor * restrict ids = octx->src[2];
|
||||
struct htp_spad * restrict src2_spad = &octx->src2_spad;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
@@ -2612,7 +2613,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
|
||||
const struct htp_tensor * src = &octx->src1;
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
struct htp_spad * spad = &octx->src0_spad;
|
||||
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
|
||||
@@ -2659,7 +2660,7 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
|
||||
const struct htp_tensor * src = &octx->src1;
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
|
||||
uint32_t dst_stride = octx->src1_spad.stride;
|
||||
@@ -2701,7 +2702,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
|
||||
const struct htp_tensor * src = &octx->src1;
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
|
||||
uint32_t dst_stride = octx->src1_spad.stride;
|
||||
@@ -2800,7 +2801,7 @@ static void htp_mminit_spad(struct htp_ops_context * octx,
|
||||
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
||||
}
|
||||
|
||||
int op_matmul(struct htp_ops_context * octx) {
|
||||
static int op_matmul_hvx(struct htp_ops_context * octx) {
|
||||
htp_matmul_tensors_preamble;
|
||||
|
||||
struct htp_matmul_context mmctx_struct = {0};
|
||||
@@ -2824,7 +2825,7 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
worker_callback_t quant_job_func;
|
||||
worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d;
|
||||
|
||||
bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
|
||||
bool need_quant = true;
|
||||
|
||||
if (src0->type == HTP_TYPE_F16) {
|
||||
// Try optimized f16-f16 path first (src1 in VTCM)
|
||||
@@ -2838,7 +2839,7 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
// Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
|
||||
// It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
|
||||
const bool is_batched = (ne02 > 1) || (ne03 > 1);
|
||||
const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
|
||||
const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]);
|
||||
|
||||
if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
|
||||
// Optimized path
|
||||
@@ -2915,34 +2916,172 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
// Place src1 spad first. We use it for dyn.quant and may reuse between ops
|
||||
octx->src1_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL;
|
||||
octx->src0_spad.src = NULL;
|
||||
octx->dst_spad.src = NULL;
|
||||
|
||||
octx->src0_spad.stride = src0_row_size_padded;
|
||||
octx->src1_spad.stride = src1_row_size;
|
||||
|
||||
if (need_quant) {
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)
|
||||
return HTP_STATUS_OK;
|
||||
|
||||
if (need_quant && !octx->src1_spad.src) {
|
||||
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
|
||||
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
|
||||
// Cache where src1 was written so subsequent SKIP_QUANTIZE ops can find it
|
||||
octx->ctx->prev_src1_spad = octx->src1_spad.data;
|
||||
} else {
|
||||
// SKIP_QUANTIZE: Q8 data lives at the address written by the previous
|
||||
// quantize pass. The current op may have a different src0 size (e.g.
|
||||
// IQ4_NL vs MXFP4), so src1_spad.data computed above could be wrong.
|
||||
octx->src1_spad.data = octx->ctx->prev_src1_spad;
|
||||
octx->src1_spad.src = src1;
|
||||
}
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
const uint32_t n_matmul_jobs = octx->n_threads;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
|
||||
}
|
||||
const uint32_t n_matmul_jobs = octx->n_threads;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
int op_matmul(struct htp_ops_context * octx) {
|
||||
htp_matmul_tensors_preamble;
|
||||
|
||||
#ifndef HTP_HAS_HMX
|
||||
return op_matmul_hvx(octx);
|
||||
#else
|
||||
if (!octx->ctx->hmx_enabled) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// HMX weight tile requires N to be 32-aligned.
|
||||
if (src0->ne[1] % 32 != 0) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights.
|
||||
// Other types fall back to HVX.
|
||||
uint32_t wtype = src0->type;
|
||||
if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// Quantised HMX path requires K aligned to 256 (x4x2 super-block).
|
||||
// F16 HMX path requires K aligned to 32 (tile width).
|
||||
if (wtype != HTP_TYPE_F16 && src0->ne[0] % 256 != 0) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
if (wtype == HTP_TYPE_F16 && src0->ne[0] % 32 != 0) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
const bool is_batched = (src0->ne[2] * src0->ne[3] > 1 || src1->ne[2] * src1->ne[3] > 1);
|
||||
|
||||
// Quantised HMX kernels only handle flat 2D matmul (host already rejects
|
||||
// batched quantised, but guard here too). F16 batched matmul is handled
|
||||
// by the dedicated wrapper in hmx-matmul-ops.c.
|
||||
if (is_batched && src0->type != HTP_TYPE_F16) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// HMX assumes contiguous row-major layout. Fall back for permuted
|
||||
// tensors where strides are non-monotonic (e.g. transposed KV cache).
|
||||
if (src0->nb[0] > src0->nb[1] || src1->nb[0] > src1->nb[1]) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// M alignment: when M > 32 but not 32-aligned, we split into
|
||||
// HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
|
||||
// When M <= 32 and not 32-aligned, fall back entirely to HVX.
|
||||
const int m_total = (int) src1->ne[1];
|
||||
const int m_tail = m_total % 32;
|
||||
const int m_hmx = m_total - m_tail;
|
||||
|
||||
if (m_hmx == 0) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// Always re-quantize src1 since HMX kernel overwrites vtcm/spad,
|
||||
// so any previously cached quantized data is invalid.
|
||||
octx->src1_spad.src = NULL;
|
||||
|
||||
int k = (int) src0->ne[0]; // inner dimension
|
||||
int n = (int) src0->ne[1]; // weight columns
|
||||
|
||||
// --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
|
||||
int ret = -1;
|
||||
|
||||
// Row strides in elements. For compact tensors these equal k; for
|
||||
// permuted attention views they can be larger, so pass the real stride.
|
||||
const int act_stride = (int)(src1->nb[1] / sizeof(float));
|
||||
const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16));
|
||||
|
||||
if (src0->type == HTP_TYPE_F16) {
|
||||
if (is_batched) {
|
||||
hmx_matmul_w16a32_batched_params_t batch_params = {
|
||||
.dst = (float *) dst->data,
|
||||
.activation = (float *) src1->data,
|
||||
.permuted_weight = (const __fp16 *) src0->data,
|
||||
.m = m_hmx,
|
||||
.k = k,
|
||||
.n = n,
|
||||
.act_stride = act_stride,
|
||||
.weight_stride = wgt_stride,
|
||||
.dst_stride = (int) (dst->nb[1] / sizeof(float)),
|
||||
.ne02 = ne02,
|
||||
.ne03 = ne03,
|
||||
.ne12 = ne12,
|
||||
.ne13 = ne13,
|
||||
.src0_nb2 = src0->nb[2],
|
||||
.src0_nb3 = src0->nb[3],
|
||||
.src1_nb2 = src1->nb[2],
|
||||
.src1_nb3 = src1->nb[3],
|
||||
.dst_nb2 = dst->nb[2],
|
||||
.dst_nb3 = dst->nb[3],
|
||||
};
|
||||
ret = hmx_mat_mul_permuted_w16a32_batched(octx->ctx, &batch_params);
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_w16a32(octx->ctx,
|
||||
(float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data,
|
||||
m_hmx, k, n, act_stride, wgt_stride);
|
||||
}
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx,
|
||||
(float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
|
||||
m_hmx, k, n, (int) src0->type);
|
||||
}
|
||||
|
||||
if (ret != 0) {
|
||||
FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret);
|
||||
return op_matmul(octx);
|
||||
}
|
||||
|
||||
// --- Phase 2: HVX on the remaining m_tail rows ---
|
||||
if (m_tail > 0) {
|
||||
// copy of src1 and dst
|
||||
struct htp_tensor src1_tail = *src1;
|
||||
struct htp_tensor dst_tail = *dst;
|
||||
|
||||
src1_tail.ne[1] = m_tail; // only tail rows
|
||||
dst_tail.ne[1] = m_tail; // only tail rows
|
||||
|
||||
// Offset activation and dst pointers past the HMX-processed rows.
|
||||
// Use nb[1] (row stride in bytes) to compute the byte offset.
|
||||
src1_tail.data += (uint32_t) m_hmx * src1->nb[1];
|
||||
dst_tail.data += (uint32_t) m_hmx * dst->nb[1];
|
||||
|
||||
octx->src[1] = &src1_tail;
|
||||
octx->dst = &dst_tail;
|
||||
|
||||
FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data);
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
return 0;
|
||||
#endif // HTP_HAS_HMX
|
||||
}
|
||||
|
||||
int op_matmul_id(struct htp_ops_context * octx) {
|
||||
htp_matmul_tensors_preamble;
|
||||
|
||||
@@ -2950,7 +3089,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
||||
struct htp_matmul_context * mmctx = &mmctx_struct;
|
||||
mmctx->octx = octx;
|
||||
|
||||
struct htp_tensor * restrict ids = &octx->src2;
|
||||
const struct htp_tensor * restrict ids = octx->src[2];
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
@@ -3003,11 +3142,17 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
// Place src1 spad first. We use it for dyn.quant and may reuse in subseq ops.
|
||||
octx->src1_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
octx->src2_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size;
|
||||
|
||||
octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL;
|
||||
octx->src0_spad.src = NULL;
|
||||
octx->src2_spad.src = NULL;
|
||||
octx->dst_spad.src = NULL;
|
||||
|
||||
octx->src0_spad.stride = src0_row_size_padded;
|
||||
octx->src1_spad.stride = src1_row_size;
|
||||
|
||||
@@ -3031,20 +3176,18 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
||||
}
|
||||
}
|
||||
|
||||
// Setup worker pool callbacks
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)
|
||||
return HTP_STATUS_OK;
|
||||
|
||||
if (octx->src1_spad.src != src1) {
|
||||
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
|
||||
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
|
||||
octx->ctx->prev_src1_spad = octx->src1_spad.data;
|
||||
} else {
|
||||
octx->src1_spad.data = octx->ctx->prev_src1_spad;
|
||||
octx->src1_spad.src = src1;
|
||||
}
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
const uint32_t n_matmul_jobs = octx->n_threads;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
|
||||
}
|
||||
const uint32_t n_matmul_jobs = octx->n_threads;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
struct htp_repeat_context {
|
||||
@@ -32,8 +32,8 @@ struct htp_repeat_context {
|
||||
static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_repeat_context * rctx = (const struct htp_repeat_context *) data;
|
||||
struct htp_ops_context * octx = rctx->octx;
|
||||
const struct htp_tensor * src = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src = octx->src[0];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const uint32_t ne00 = src->ne[0];
|
||||
const uint32_t ne01 = src->ne[1];
|
||||
@@ -98,8 +98,8 @@ static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * dat
|
||||
}
|
||||
|
||||
int op_repeat(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
// Validate that dst dims are multiples of src dims
|
||||
if (dst->ne[0] % src0->ne[0] != 0 ||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h
|
||||
@@ -253,10 +253,10 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_rope_context * rctx = (struct htp_rope_context *) data;
|
||||
struct htp_ops_context * octx = rctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * src2 = octx->src[2];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
htp_rope_preamble;
|
||||
|
||||
@@ -284,7 +284,7 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
const float * freq_factors = src2->data ? (const float *) src2->data : NULL;
|
||||
const float * freq_factors = src2 ? (const float *) src2->data : NULL;
|
||||
|
||||
uint32_t ir = 0;
|
||||
uint32_t prev_i2 = (uint32_t) -1;
|
||||
@@ -384,10 +384,10 @@ done:
|
||||
static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * src2 = octx->src[2];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const char * op_type = "rope-f32";
|
||||
|
||||
@@ -424,19 +424,16 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
// Assign sizes
|
||||
octx->src0_spad.size_per_thread = src0_spad_per_thread;
|
||||
octx->dst_spad.size_per_thread = dst_spad_per_thread;
|
||||
octx->src0_spad.size = n_threads * src0_spad_per_thread;
|
||||
octx->dst_spad.size = n_threads * dst_spad_per_thread;
|
||||
octx->src1_spad.size = 0;
|
||||
|
||||
// Assign pointers
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = NULL;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.data = NULL; octx->src1_spad.src = NULL;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL;
|
||||
|
||||
// Fill context
|
||||
struct htp_rope_context rctx;
|
||||
memset(&rctx, 0, sizeof(struct htp_rope_context));
|
||||
|
||||
@@ -483,7 +480,7 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||
int op_rope(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
switch (octx->src[0]->type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_rope_f32(octx);
|
||||
break;
|
||||
|
||||
@@ -14,33 +14,37 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#define set_rows_preamble \
|
||||
const uint32_t ne00 = octx->src0.ne[0]; \
|
||||
const uint32_t ne01 = octx->src0.ne[1]; \
|
||||
const uint32_t ne02 = octx->src0.ne[2]; \
|
||||
const uint32_t ne03 = octx->src0.ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = octx->src1.ne[0]; \
|
||||
const uint32_t ne11 = octx->src1.ne[1]; \
|
||||
const uint32_t ne12 = octx->src1.ne[2]; \
|
||||
\
|
||||
const uint32_t nb01 = octx->src0.nb[1]; \
|
||||
const uint32_t nb02 = octx->src0.nb[2]; \
|
||||
const uint32_t nb03 = octx->src0.nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = octx->src1.nb[0]; \
|
||||
const uint32_t nb11 = octx->src1.nb[1]; \
|
||||
const uint32_t nb12 = octx->src1.nb[2]; \
|
||||
\
|
||||
const uint32_t nb1 = octx->dst.nb[1]; \
|
||||
const uint32_t nb2 = octx->dst.nb[2]; \
|
||||
const uint32_t nb3 = octx->dst.nb[3]; \
|
||||
\
|
||||
const uint32_t ne1 = octx->dst.ne[1]; \
|
||||
\
|
||||
#define set_rows_preamble \
|
||||
const uint32_t ne00 = octx->src[0]->ne[0]; \
|
||||
const uint32_t ne01 = octx->src[0]->ne[1]; \
|
||||
const uint32_t ne02 = octx->src[0]->ne[2]; \
|
||||
const uint32_t ne03 = octx->src[0]->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = octx->src[1]->ne[0]; \
|
||||
const uint32_t ne11 = octx->src[1]->ne[1]; \
|
||||
const uint32_t ne12 = octx->src[1]->ne[2]; \
|
||||
const uint32_t ne13 = octx->src[1]->ne[3]; \
|
||||
\
|
||||
const uint32_t nb01 = octx->src[0]->nb[1]; \
|
||||
const uint32_t nb02 = octx->src[0]->nb[2]; \
|
||||
const uint32_t nb03 = octx->src[0]->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = octx->src[1]->nb[0]; \
|
||||
const uint32_t nb11 = octx->src[1]->nb[1]; \
|
||||
const uint32_t nb12 = octx->src[1]->nb[2]; \
|
||||
\
|
||||
const uint32_t nb1 = octx->dst->nb[1]; \
|
||||
const uint32_t nb2 = octx->dst->nb[2]; \
|
||||
const uint32_t nb3 = octx->dst->nb[3]; \
|
||||
\
|
||||
const uint32_t ne0 = octx->dst->ne[0]; \
|
||||
const uint32_t ne1 = octx->dst->ne[1]; \
|
||||
const uint32_t ne2 = octx->dst->ne[2]; \
|
||||
const uint32_t ne3 = octx->dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nr = ne01;
|
||||
|
||||
struct htp_set_rows_context {
|
||||
@@ -56,12 +60,14 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
||||
|
||||
set_rows_preamble;
|
||||
|
||||
uint64_t qt = HAP_perf_get_qtimer_count();
|
||||
|
||||
// parallelize by rows of src0
|
||||
const uint32_t dr = srctx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
|
||||
const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32);
|
||||
|
||||
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
|
||||
@@ -70,7 +76,7 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
|
||||
const uint32_t i10 = i;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
||||
uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
|
||||
if (i1 >= ne1) {
|
||||
@@ -78,14 +84,18 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
||||
continue;
|
||||
}
|
||||
|
||||
const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
|
||||
const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
|
||||
const uintptr_t src0_ptr = octx->src[0]->data + i*nb01 + i02*nb02 + i03*nb03;
|
||||
const uintptr_t dst_ptr = octx->dst->data + i1*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
// copy row
|
||||
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt);
|
||||
FARF(HIGH, "set-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt);
|
||||
}
|
||||
|
||||
static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
@@ -94,12 +104,14 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da
|
||||
|
||||
set_rows_preamble;
|
||||
|
||||
uint64_t qt = HAP_perf_get_qtimer_count();
|
||||
|
||||
// parallelize by rows of src0
|
||||
const uint32_t dr = srctx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
|
||||
const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32);
|
||||
|
||||
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
|
||||
@@ -108,7 +120,7 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
|
||||
const uint32_t i10 = i;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
||||
uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
|
||||
if (i1 >= ne1) {
|
||||
@@ -116,13 +128,17 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
|
||||
uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
|
||||
const uint8_t* src0_ptr = (const uint8_t *) octx->src[0]->data + i*nb01 + i02*nb02 + i03*nb03;
|
||||
uint8_t* dst_ptr = (uint8_t *) octx->dst->data + i1*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt);
|
||||
FARF(HIGH, "set-rows-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt);
|
||||
}
|
||||
|
||||
int op_set_rows(struct htp_ops_context * octx) {
|
||||
@@ -130,15 +146,15 @@ int op_set_rows(struct htp_ops_context * octx) {
|
||||
|
||||
const uint32_t n_threads = MIN(nr, octx->n_threads);
|
||||
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
if (octx->src[0]->type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) {
|
||||
if (octx->dst->type != HTP_TYPE_F32 && octx->dst->type != HTP_TYPE_F16) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
|
||||
if (octx->src[1]->type != HTP_TYPE_I32 && octx->src[1]->type != HTP_TYPE_I64) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
@@ -153,7 +169,7 @@ int op_set_rows(struct htp_ops_context * octx) {
|
||||
|
||||
srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;
|
||||
|
||||
switch(octx->dst.type) {
|
||||
switch(octx->dst->type) {
|
||||
case HTP_TYPE_F32:
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads);
|
||||
break;
|
||||
|
||||
@@ -15,68 +15,89 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#define htp_softmax_preamble3 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t ne10 = (src1->ne[0]) ? src1->ne[0] : 1; \
|
||||
const uint32_t ne11 = (src1->ne[0]) ? src1->ne[1] : 1; \
|
||||
const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; \
|
||||
const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; \
|
||||
\
|
||||
const uint32_t nb10 = (src1->ne[0]) ? src1->nb[0] : 1; \
|
||||
const uint32_t nb11 = (src1->ne[0]) ? src1->nb[1] : 1; \
|
||||
const uint32_t nb12 = (src1->ne[0]) ? src1->nb[2] : 1; \
|
||||
const uint32_t nb13 = (src1->ne[0]) ? src1->nb[3] : 1; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
#define htp_softmax_preamble3 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1 ? src1->ne[0] : 1; \
|
||||
const uint32_t ne11 = src1 ? src1->ne[1] : 1; \
|
||||
const uint32_t ne12 = src1 ? src1->ne[2] : 1; \
|
||||
const uint32_t ne13 = src1 ? src1->ne[3] : 1; \
|
||||
\
|
||||
const uint32_t nb10 = src1 ? src1->nb[0] : 1; \
|
||||
const uint32_t nb11 = src1 ? src1->nb[1] : 1; \
|
||||
const uint32_t nb12 = src1 ? src1->nb[2] : 1; \
|
||||
const uint32_t nb13 = src1 ? src1->nb[3] : 1; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct htp_softmax_context {
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
bool use_f16;
|
||||
bool use_src1;
|
||||
|
||||
uint32_t n_head;
|
||||
uint32_t n_head_log2;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t src0_nrows_per_thread;
|
||||
struct fastdiv_values fastdiv_ne01;
|
||||
struct fastdiv_values fastdiv_ne02;
|
||||
struct fastdiv_values fastdiv_ne12; // For mask broadcasting
|
||||
struct fastdiv_values fastdiv_ne13; // For mask broadcasting
|
||||
size_t spad_stride;
|
||||
|
||||
struct htp_ops_context * octx;
|
||||
uint32_t src0_nrows_per_thread;
|
||||
};
|
||||
|
||||
static void apply_mask(float * restrict wp0,
|
||||
const float * restrict mp_f32,
|
||||
const __fp16 * restrict mp_f16,
|
||||
uint32_t ne00,
|
||||
float slope,
|
||||
bool use_f16) {
|
||||
if (!mp_f32) {
|
||||
return;
|
||||
}
|
||||
if (use_f16) {
|
||||
for (uint32_t i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * (float) mp_f16[i];
|
||||
}
|
||||
} else {
|
||||
for (uint32_t i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * mp_f32[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
|
||||
memset(smctx, 0, sizeof(struct htp_softmax_context));
|
||||
|
||||
memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float));
|
||||
memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float));
|
||||
memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
|
||||
smctx->n_head = src0->ne[2];
|
||||
@@ -85,8 +106,8 @@ static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_
|
||||
smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2);
|
||||
smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2);
|
||||
|
||||
smctx->use_src1 = (src1->ne[0] != 0);
|
||||
smctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
|
||||
smctx->use_src1 = (src1 != 0);
|
||||
smctx->use_f16 = (src1 != 0) && (src1->type == HTP_TYPE_F16);
|
||||
|
||||
smctx->octx = octx;
|
||||
|
||||
@@ -97,8 +118,8 @@ static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_
|
||||
if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01);
|
||||
if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02);
|
||||
|
||||
const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1;
|
||||
const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1;
|
||||
const uint32_t ne12 = src1 ? src1->ne[2] : 1;
|
||||
const uint32_t ne13 = src1 ? src1->ne[3] : 1;
|
||||
|
||||
if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12);
|
||||
if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13);
|
||||
@@ -139,10 +160,7 @@ static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
||||
uint8_t * restrict dst,
|
||||
uint8_t * restrict pad,
|
||||
const int num_elems) {
|
||||
static void hvx_fast_softmax_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict pad, const int num_elems) {
|
||||
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict v_pad = (HVX_Vector *) pad;
|
||||
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||
@@ -188,27 +206,20 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
||||
}
|
||||
}
|
||||
|
||||
static float hvx_softmax_f32(const uint8_t * restrict src,
|
||||
uint8_t * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const int num_elems,
|
||||
const float max) {
|
||||
static float hvx_softmax_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict spad, const int num_elems, const float max) {
|
||||
hvx_sub_scalar_f32(spad, src, max, num_elems);
|
||||
|
||||
hvx_exp_f32(dst, spad, num_elems, false);
|
||||
|
||||
float sum = hvx_reduce_sum_f32(dst, num_elems);
|
||||
|
||||
return sum;
|
||||
return hvx_reduce_sum_f32(dst, num_elems);
|
||||
}
|
||||
|
||||
static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_softmax_context * smctx = (struct htp_softmax_context *) data;
|
||||
struct htp_ops_context * octx = smctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
htp_softmax_preamble3;
|
||||
|
||||
@@ -223,22 +234,26 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
uint64_t qt = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
|
||||
if (!hex_is_aligned((void *) src0->data, VLEN) || !hex_is_aligned((void *) dst->data, VLEN)) {
|
||||
is_aligned = 0;
|
||||
FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
// Only use the fast path when aligned AND row size is multiple of VLEN (128 bytes)
|
||||
// The fast path (hvx_fast_softmax_f32) doesn't handle tail elements
|
||||
// The non-opt path uses hvx_softmax_f32 which properly handles all sizes via its helper functions
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride);
|
||||
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride);
|
||||
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * smctx->spad_stride);
|
||||
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
|
||||
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
|
||||
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
|
||||
|
||||
float * wp0 = (float *) src0_spad_data;
|
||||
float * wp1 = (float *) src1_spad_data;
|
||||
@@ -278,47 +293,29 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
// ALiBi
|
||||
if (i2 != prev_i2) {
|
||||
const uint32_t h = i2; // head
|
||||
|
||||
slope = (smctx->max_bias > 0.0f) ?
|
||||
h < smctx->n_head_log2 ?
|
||||
powf(smctx->m0, h + 1) :
|
||||
powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) :
|
||||
1.0f;
|
||||
slope = (smctx->max_bias > 0.0f) ? h < smctx->n_head_log2 ? powf(smctx->m0, h + 1) : powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) : 1.0f;
|
||||
prev_i2 = i2;
|
||||
}
|
||||
|
||||
float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03);
|
||||
float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3);
|
||||
float * sp = (float *) ((char *) src0->data + i1 * nb01 + i2 * nb02 + i3 * nb03);
|
||||
float * dp = (float *) ((char *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3);
|
||||
|
||||
// broadcast the mask across rows
|
||||
__fp16 * mp_f16 = (smctx->use_src1) ?
|
||||
(__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
float * mp_f32 = (smctx->use_src1) ?
|
||||
(float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
__fp16 * mp_f16 = (smctx->use_src1) ? (__fp16 *) ((char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13) : NULL;
|
||||
float * mp_f32 = (smctx->use_src1) ? (float *) ((char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13) : NULL;
|
||||
|
||||
if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) {
|
||||
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale,
|
||||
(const uint8_t *) mp_f32, slope);
|
||||
} else {
|
||||
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale, (const uint8_t *) mp_f32, slope);
|
||||
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
|
||||
} else if (1 == opt_path) {
|
||||
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale);
|
||||
if (mp_f32) {
|
||||
if (smctx->use_f16) {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * (float) mp_f16[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * mp_f32[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
apply_mask(wp0, mp_f32, mp_f16, ne00, slope, smctx->use_f16);
|
||||
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
|
||||
} else {
|
||||
// Non-optimized path: uses HVX helper functions that properly handle all tensor sizes
|
||||
// including non-multiples of 32 (the HVX vector lane count for f32)
|
||||
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale);
|
||||
apply_mask(wp0, mp_f32, mp_f16, ne00, slope, smctx->use_f16);
|
||||
float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
|
||||
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
|
||||
sum = sum > 0.0 ? (1.0 / sum) : 1;
|
||||
@@ -326,54 +323,47 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
|
||||
ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt);
|
||||
FARF(HIGH, "softmax-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u : opt %u f16 %u usec %u\n", ith, nth,
|
||||
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
|
||||
ne0, ne1, ne2, ne3, opt_path, smctx->use_f16, (unsigned) qt);
|
||||
}
|
||||
|
||||
static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
struct htp_softmax_context smctx;
|
||||
const char * op_type = "softmax-f32";
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_SOFTMAX:
|
||||
init_softmax_ctx(&smctx, octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported Op %u\n", octx->op);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
init_softmax_ctx(&smctx, octx);
|
||||
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
|
||||
smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t src1_row_size = src0_row_size;
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
|
||||
// 4 rows per thread, padded to HVX vector size
|
||||
octx->src0_spad.size_per_thread = hex_round_up(4 * src0_row_size, 128);
|
||||
octx->src1_spad.size_per_thread = hex_round_up(4 * src1_row_size, 128);
|
||||
octx->dst_spad.size_per_thread = hex_round_up(4 * dst_row_size, 128);
|
||||
|
||||
// Use stride for calculating offset
|
||||
smctx.spad_stride = hex_round_up(src0_row_size, 128);
|
||||
octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads;
|
||||
octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads;
|
||||
octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads;
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
if (src1->ne[0]) {
|
||||
FARF(HIGH,
|
||||
"%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
if (src1) {
|
||||
FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||
octx->dst_spad.size);
|
||||
@@ -385,19 +375,17 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size);
|
||||
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, spad_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads);
|
||||
}
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) return err;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads);
|
||||
|
||||
return err;
|
||||
}
|
||||
@@ -405,7 +393,7 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
int op_softmax(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
switch (octx->src[0]->type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_softmax_f32(octx);
|
||||
break;
|
||||
|
||||
@@ -16,14 +16,14 @@
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "hex-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define htp_ssm_conv_tensors_preamble \
|
||||
struct htp_tensor * restrict src0 = &octx->src0; \
|
||||
struct htp_tensor * restrict src1 = &octx->src1; \
|
||||
struct htp_tensor * restrict dst = &octx->dst; \
|
||||
#define htp_ssm_conv_tensors_preamble \
|
||||
const struct htp_tensor * restrict src0 = octx->src[0]; \
|
||||
const struct htp_tensor * restrict src1 = octx->src[1]; \
|
||||
const struct htp_tensor * restrict dst = octx->dst; \
|
||||
struct htp_spad * restrict src0_spad = &octx->src0_spad; \
|
||||
struct htp_spad * restrict src1_spad = &octx->src1_spad; \
|
||||
struct htp_spad * restrict dst_spad = &octx->dst_spad; \
|
||||
@@ -289,9 +289,9 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
|
||||
// Compute gather scratchpad size for src0 and src1
|
||||
const size_t gather_spad_size = n_threads * VLEN * 2;
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL;
|
||||
|
||||
FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n",
|
||||
gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread,
|
||||
@@ -323,8 +323,9 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
|
||||
}
|
||||
|
||||
int op_ssm_conv(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (dst->type) {
|
||||
case HTP_TYPE_F32:
|
||||
|
||||
@@ -14,13 +14,13 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#define sum_rows_preamble \
|
||||
struct htp_tensor *src0 = &octx->src0;\
|
||||
struct htp_tensor *dst = &octx->dst; \
|
||||
\
|
||||
#define sum_rows_preamble \
|
||||
const struct htp_tensor *src0 = octx->src[0]; \
|
||||
const struct htp_tensor *dst = octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
@@ -94,7 +94,7 @@ static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data)
|
||||
int op_sum_rows(struct htp_ops_context * octx) {
|
||||
sum_rows_preamble;
|
||||
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
if (octx->src[0]->type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
struct htp_unary_context {
|
||||
@@ -267,8 +267,8 @@ static void softplus_f32(const float * restrict src,
|
||||
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
|
||||
struct htp_ops_context * octx = uctx->octx;
|
||||
const struct htp_tensor * src = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src = octx->src[0];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
htp_unary_preamble;
|
||||
|
||||
@@ -387,8 +387,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const char * op_type = NULL;
|
||||
|
||||
@@ -490,7 +490,7 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
int op_unary(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
switch (octx->src[0]->type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_unary_f32(octx);
|
||||
break;
|
||||
|
||||
@@ -1115,6 +1115,32 @@ class ggml_webgpu_shader_lib {
|
||||
std::string type_upper = type_str;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
switch (key.src_type)
|
||||
{
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
{
|
||||
// Quantized types using u32 buffers for portability.
|
||||
defines.push_back("SRC_TYPE=u32");
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
defines.push_back(std::string("SRC_TYPE=") + type_str);
|
||||
}
|
||||
}
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back(type_upper + "_T");
|
||||
defines.push_back(type_upper);
|
||||
@@ -1125,7 +1151,6 @@ class ggml_webgpu_shader_lib {
|
||||
variant += "_";
|
||||
variant += type_str;
|
||||
|
||||
defines.push_back(std::string("SRC_TYPE=") + type_str);
|
||||
defines.push_back("DST_TYPE=f32");
|
||||
|
||||
if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
|
||||
@@ -1593,11 +1618,35 @@ class ggml_webgpu_shader_lib {
|
||||
break;
|
||||
default:
|
||||
{
|
||||
// quantized types
|
||||
std::string type_upper = src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
|
||||
switch (context.src0->type)
|
||||
{
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
{
|
||||
// Quantized types using u32 buffers for portability.
|
||||
defines.push_back("SRC0_TYPE=u32");
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
|
||||
}
|
||||
}
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back(type_upper + "_T");
|
||||
defines.push_back(type_upper);
|
||||
|
||||
@@ -97,6 +97,14 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
|
||||
|
||||
/* End Constants */
|
||||
|
||||
static inline wgpu::CallbackMode ggml_webgpu_callback_mode() {
|
||||
#ifdef __EMSCRIPTEN__
|
||||
return wgpu::CallbackMode::AllowProcessEvents;
|
||||
#else
|
||||
return wgpu::CallbackMode::AllowSpontaneous;
|
||||
#endif
|
||||
}
|
||||
|
||||
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
|
||||
// their locations.
|
||||
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
|
||||
@@ -474,7 +482,7 @@ static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) {
|
||||
|
||||
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
|
||||
ctx->queue.OnSubmittedWorkDone(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
ggml_webgpu_callback_mode(),
|
||||
[&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
callback_status = status;
|
||||
callback_message = std::string(message);
|
||||
@@ -494,7 +502,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
|
||||
std::string callback_message;
|
||||
|
||||
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
|
||||
buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
|
||||
buffer.MapAsync(mode, offset, size, ggml_webgpu_callback_mode(),
|
||||
[&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
callback_status = status;
|
||||
callback_message = std::string(message);
|
||||
@@ -526,7 +534,11 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
|
||||
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
|
||||
wgpu::CommandBuffer commands = encoder.Finish();
|
||||
ctx->queue.Submit(1, &commands);
|
||||
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
|
||||
if (!ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0,
|
||||
ctx->debug_host_buf.GetSize())) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Debug buffer map failed\n");
|
||||
return;
|
||||
}
|
||||
const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
|
||||
std::cout << "debug[0]: " << debug_data[0] << "\n";
|
||||
ctx->debug_host_buf.Unmap();
|
||||
@@ -542,7 +554,7 @@ static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context &
|
||||
auto ts_bufs = command.timestamp_query_bufs;
|
||||
|
||||
wgpu::Future f = ts_bufs.host_buf.MapAsync(
|
||||
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
|
||||
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), ggml_webgpu_callback_mode(),
|
||||
[ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::MapAsyncStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
|
||||
@@ -3420,7 +3432,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
|
||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||
ctx->webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
&options, ggml_webgpu_callback_mode(),
|
||||
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
@@ -3449,13 +3461,15 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
// Only support square f16 matrices of size 8 or 16 for now
|
||||
// Accept f16 subgroup matrix configurations (square or non-square).
|
||||
// NVIDIA GPUs typically report square configs (e.g. 16x16x16),
|
||||
// while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
|
||||
// The shaders are already parameterized to handle any M/N/K dimensions.
|
||||
bool valid_subgroup_matrix_config = false;
|
||||
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
||||
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
|
||||
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
|
||||
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
|
||||
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
|
||||
if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
|
||||
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
|
||||
ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
|
||||
ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
|
||||
@@ -3491,8 +3505,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
dev_desc.requiredFeatures = required_features.data();
|
||||
dev_desc.requiredFeatureCount = required_features.size();
|
||||
dev_desc.SetDeviceLostCallback(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
ggml_webgpu_callback_mode(),
|
||||
[ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
if (reason == wgpu::DeviceLostReason::Destroyed) {
|
||||
return;
|
||||
}
|
||||
@@ -3525,7 +3539,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
|
||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||
ctx->webgpu_global_ctx->adapter.RequestDevice(
|
||||
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
|
||||
&dev_desc, ggml_webgpu_callback_mode(),
|
||||
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
||||
if (status != wgpu::RequestDeviceStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
|
||||
@@ -3793,6 +3807,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
|
||||
break;
|
||||
}
|
||||
// Head dimensions must be divisible by subgroup matrix dimensions
|
||||
if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
|
||||
src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
|
||||
break;
|
||||
}
|
||||
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
||||
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
@@ -4046,6 +4065,13 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
ctx.name = GGML_WEBGPU_NAME;
|
||||
ctx.device_count = 0;
|
||||
|
||||
// Keep one Dawn/WebGPU instance alive for the lifetime of the static backend
|
||||
// registry. Recreating it on repeated registry lookups can invalidate
|
||||
// adapter/device references that are still held by the backend/device layer.
|
||||
if (ctx.webgpu_global_ctx != nullptr && ctx.webgpu_global_ctx->instance != nullptr) {
|
||||
return ®
|
||||
}
|
||||
|
||||
wgpu::InstanceDescriptor instance_descriptor{};
|
||||
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
|
||||
instance_descriptor.requiredFeatures = instance_features.data();
|
||||
@@ -4063,11 +4089,11 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
|
||||
ctx.webgpu_global_ctx->instance = std::move(inst);
|
||||
|
||||
// Probe for adapter support
|
||||
wgpu::Adapter adapter;
|
||||
if (ctx.webgpu_global_ctx->instance != nullptr) {
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
// probe for adapter support
|
||||
ctx.webgpu_global_ctx->instance.WaitAny(
|
||||
ctx.webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
|
||||
@@ -9,35 +9,43 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
|
||||
#endif
|
||||
|
||||
#ifdef U32_DEQUANT_HELPERS
|
||||
fn load_src0_u16_at(byte_offset: u32) -> u32 {
|
||||
let word = src0[byte_offset / 4u];
|
||||
let shift = (byte_offset & 2u) * 8u;
|
||||
return (word >> shift) & 0xFFFFu;
|
||||
fn load_u16_at(
|
||||
buf: ptr<storage, array<u32>, read_write>,
|
||||
byte_offset: u32) -> u32 {
|
||||
let word = buf[byte_offset / 4];
|
||||
let shift = (byte_offset & 0x2) * 8;
|
||||
return (word >> shift) & 0xFFFF;
|
||||
}
|
||||
|
||||
fn load_src0_u32_at(byte_offset: u32) -> u32 {
|
||||
let word_idx = byte_offset / 4u;
|
||||
let shift = (byte_offset & 3u) * 8u;
|
||||
let lo = src0[word_idx];
|
||||
if (shift == 0u) {
|
||||
return lo;
|
||||
}
|
||||
let hi = src0[word_idx + 1u];
|
||||
return (lo >> shift) | (hi << (32u - shift));
|
||||
fn load_u32_at(
|
||||
buf: ptr<storage, array<u32>, read_write>,
|
||||
byte_offset: u32) -> u32 {
|
||||
let word_idx = byte_offset / 4;
|
||||
let shift = (byte_offset & 0x3) * 8;
|
||||
let lo = buf[word_idx];
|
||||
let hi = buf[word_idx + 1];
|
||||
let shifted = (lo >> shift) | (hi << (32 - shift));
|
||||
return select(shifted, lo, shift == 0);
|
||||
}
|
||||
|
||||
fn load_src0_f16_at(byte_offset: u32) -> f16 {
|
||||
let packed = unpack2x16float(load_src0_u16_at(byte_offset));
|
||||
fn load_f16_at(
|
||||
buf: ptr<storage, array<u32>, read_write>,
|
||||
byte_offset: u32) -> f16 {
|
||||
let packed = unpack2x16float(load_u16_at(buf, byte_offset));
|
||||
return f16(packed[0]);
|
||||
}
|
||||
|
||||
fn load_f16_as_f32_at(
|
||||
buf: ptr<storage, array<u32>, read_write>,
|
||||
byte_offset: u32) -> f32 {
|
||||
let word = buf[byte_offset / 4];
|
||||
let shift = (byte_offset & 0x2) * 8;
|
||||
let d_bits = (word >> shift) & 0xFFFF;
|
||||
return unpack2x16float(d_bits)[0];
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q4_0_T
|
||||
struct q4_0 {
|
||||
d: f16,
|
||||
qs: array<f16, 8>
|
||||
};
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef Q4_1_T
|
||||
struct q4_1 {
|
||||
@@ -47,13 +55,6 @@ struct q4_1 {
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q5_0_T
|
||||
struct q5_0 {
|
||||
d: f16,
|
||||
qh: array<f16, 2>,
|
||||
qs: array<f16, 8>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q5_1_T
|
||||
struct q5_1 {
|
||||
@@ -64,12 +65,6 @@ struct q5_1 {
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q8_0_T
|
||||
struct q8_0 {
|
||||
d: f16,
|
||||
qs: array<f16, 16>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q8_1_T
|
||||
struct q8_1 {
|
||||
@@ -88,14 +83,6 @@ struct q2_K {
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q3_K_T
|
||||
struct q3_K {
|
||||
hmask: array<f16, 16>,
|
||||
qs: array<f16, 32>,
|
||||
scales: array<f16, 6>,
|
||||
d: f16
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
|
||||
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
|
||||
@@ -132,64 +119,6 @@ struct q5_K {
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q6_K_T
|
||||
struct q6_K {
|
||||
ql: array<f16, 64>,
|
||||
qh: array<f16, 32>,
|
||||
scales: array<f16, 8>,
|
||||
d: f16
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_XXS_T
|
||||
struct iq2_xxs {
|
||||
d: f16,
|
||||
qs: array<f16, 32>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_XS_T
|
||||
struct iq2_xs {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
scales: array<f16, 4>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_S_T
|
||||
struct iq2_s {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
qh: array<f16, 4>,
|
||||
scales: array<f16, 4>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ3_XXS_T
|
||||
struct iq3_xxs {
|
||||
d: f16,
|
||||
qs: array<f16, 48>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ3_S_T
|
||||
struct iq3_s {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
qh: array<f16, 4>,
|
||||
signs: array<f16, 16>,
|
||||
scales: array<f16, 2>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ1_S_T
|
||||
struct iq1_s {
|
||||
d: f16,
|
||||
qs: array<f16, 16>,
|
||||
qh: array<f16, 8>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ1_M_T
|
||||
struct iq1_m {
|
||||
qs: array<u32, 8>,
|
||||
@@ -198,17 +127,9 @@ struct iq1_m {
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ4_NL_T
|
||||
struct iq4_nl {
|
||||
d: f16,
|
||||
qs: array<f16, 8>,
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ4_XS_T
|
||||
struct iq4_xs {
|
||||
d: f16,
|
||||
scales_h: f16,
|
||||
d_scales_h: u32,
|
||||
scales_l: u32,
|
||||
qs: array<u32, 32>
|
||||
};
|
||||
|
||||
@@ -369,35 +369,35 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
#endif
|
||||
for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
|
||||
let inter_offset = kv_block * SG_MAT_N;
|
||||
var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
|
||||
var acc: subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M>>(&inter_shmem, inter_offset, false, KV_TILE);
|
||||
|
||||
var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, 0u, false, HEAD_DIM_QK);
|
||||
var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, 0u, false, HEAD_DIM_QK);
|
||||
|
||||
#ifdef KV_DIRECT
|
||||
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + 0u, true, params.stride_k1);
|
||||
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + 0u, true, params.stride_k1);
|
||||
#else
|
||||
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
|
||||
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
|
||||
#endif
|
||||
|
||||
var t: u32 = 1u;
|
||||
for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
|
||||
let h0 = t * SG_MAT_K;
|
||||
var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h0, false, HEAD_DIM_QK);
|
||||
var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h0, false, HEAD_DIM_QK);
|
||||
#ifdef KV_DIRECT
|
||||
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h0, true, params.stride_k1);
|
||||
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h0, true, params.stride_k1);
|
||||
#else
|
||||
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
|
||||
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
|
||||
#endif
|
||||
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
||||
q_cur = q0;
|
||||
k_cur = k0;
|
||||
|
||||
let h1 = (t + 1u) * SG_MAT_K;
|
||||
var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h1, false, HEAD_DIM_QK);
|
||||
var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h1, false, HEAD_DIM_QK);
|
||||
#ifdef KV_DIRECT
|
||||
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h1, true, params.stride_k1);
|
||||
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h1, true, params.stride_k1);
|
||||
#else
|
||||
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
|
||||
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
|
||||
#endif
|
||||
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
||||
q_cur = q1g;
|
||||
@@ -407,11 +407,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
// handle odd tail
|
||||
if (t < HEAD_DIM_QK / SG_MAT_K) {
|
||||
let h = t * SG_MAT_K;
|
||||
var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h, false, HEAD_DIM_QK);
|
||||
var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h, false, HEAD_DIM_QK);
|
||||
#ifdef KV_DIRECT
|
||||
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h, true, params.stride_k1);
|
||||
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h, true, params.stride_k1);
|
||||
#else
|
||||
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
|
||||
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
|
||||
#endif
|
||||
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
||||
q_cur = qn;
|
||||
@@ -566,7 +566,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
head_dim_block < HEAD_DIM_V;
|
||||
head_dim_block += num_subgroups * SG_MAT_N) {
|
||||
// load O submatrix from shared memory
|
||||
var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(
|
||||
var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M>>(
|
||||
&o_shmem,
|
||||
head_dim_block,
|
||||
false,
|
||||
@@ -574,7 +574,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
);
|
||||
for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
|
||||
let p_offset = kv_block * SG_MAT_N;
|
||||
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
|
||||
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(
|
||||
&inter_shmem,
|
||||
p_offset,
|
||||
false,
|
||||
@@ -585,7 +585,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
#ifdef KV_DIRECT
|
||||
let v_block_row = kv_tile + kv_block * SG_MAT_N;
|
||||
let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;
|
||||
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
|
||||
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(
|
||||
&V,
|
||||
v_global_offset,
|
||||
false,
|
||||
@@ -593,7 +593,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
);
|
||||
#else
|
||||
let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;
|
||||
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
|
||||
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(
|
||||
&kv_shmem,
|
||||
v_block_offset + head_dim_block,
|
||||
false,
|
||||
|
||||
@@ -27,17 +27,18 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef Q4_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q4_0 = src[src_base + offset];
|
||||
let d = f32(block_q4_0.d);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
|
||||
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
for (var j: u32 = 0u; j < 4; j++) {
|
||||
let q_byte_offset = block_byte_base + 2 + j * 4;
|
||||
let q_packed = load_u32_at(&src, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
|
||||
let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f32(q_byte & 0xFu) - 8.0) * d;
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
dst[dst_offset] = q_lo;
|
||||
dst[dst_offset + 16] = q_hi;
|
||||
dst[dst_offset + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -64,17 +65,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef Q5_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q5_0 = src[src_base + offset];
|
||||
let d = f32(block_q5_0.d);
|
||||
let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
|
||||
let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
let qh_packed = load_u32_at(&src, block_byte_base + 2);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
|
||||
let q_byte_offset = block_byte_base + 6 + j * 4;
|
||||
let q_packed = load_u32_at(&src, q_byte_offset);
|
||||
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
|
||||
|
||||
let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
|
||||
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
dst[dst_offset] = q_lo;
|
||||
dst[dst_offset + 16] = q_hi;
|
||||
@@ -106,14 +112,15 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef Q8_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q8_0 = src[src_base + offset];
|
||||
let d = f32(block_q8_0.d);
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
for (var j: u32 = 0u; j < 8u; j++) {
|
||||
let q_byte_offset = block_byte_base + 2u + j * 4u;
|
||||
let q_packed = load_u32_at(&src, q_byte_offset);
|
||||
for (var k: u32 = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
let dst_offset = dst_base + offset * 32u + j * 4u + k;
|
||||
dst[dst_offset] = q_val;
|
||||
}
|
||||
}
|
||||
@@ -152,36 +159,42 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef Q3_K
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
|
||||
|
||||
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
|
||||
// and 2-bits from the last 4 bytes
|
||||
// Bytes 108-109: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base + 108);
|
||||
|
||||
// Bytes 96-107: 12 bytes of scales (3 u32s)
|
||||
let kmask1: u32 = 0x03030303;
|
||||
let kmask2: u32 = 0x0f0f0f0f;
|
||||
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
|
||||
}
|
||||
scale_vals[0] = load_u32_at(&src, block_byte_base + 96);
|
||||
scale_vals[1] = load_u32_at(&src, block_byte_base + 100);
|
||||
scale_vals[2] = load_u32_at(&src, block_byte_base + 104);
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
|
||||
scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
|
||||
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
|
||||
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
|
||||
|
||||
// convert arrays of f16 -> u32
|
||||
// Bytes 0-31: 32 bytes of hmask (8 u32s)
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
|
||||
hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 32-95: 64 bytes of qs (16 u32s)
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
|
||||
for (var i: u32 = 0u; i < 16; i++) {
|
||||
qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4);
|
||||
}
|
||||
|
||||
var dst_i = dst_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
var m: u32 = 1;
|
||||
|
||||
// 2 halves of the block (128 elements each)
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
|
||||
// 4 groups (each group has 2 blocks of 16 elements)
|
||||
@@ -191,11 +204,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let sc = get_byte(scale_vals[is / 4], is % 4);
|
||||
is++;
|
||||
let dl = d * (f32(sc) - 32.0);
|
||||
for (var l: u32 = 0u; l < 16u; l++) {
|
||||
|
||||
for (var l: u32 = 0; l < 16; l++) {
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let hm_idx = k + l;
|
||||
let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
|
||||
let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
|
||||
|
||||
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
|
||||
let qs_val = (q_byte >> shift) & 3;
|
||||
dst[dst_i] = (f32(qs_val) - hm) * dl;
|
||||
@@ -268,21 +283,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
#ifdef Q6_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes
|
||||
|
||||
// convert arrays of f16 -> u32
|
||||
// Bytes 208-209: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base + 208);
|
||||
|
||||
// Bytes 0-127: 128 bytes of ql (32 u32s)
|
||||
var ql_vals: array<u32, 32>;
|
||||
for (var i: u32 = 0; i < 32; i++) {
|
||||
ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
|
||||
ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 128-191: 64 bytes of qh (16 u32s)
|
||||
var qh_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
|
||||
for (var i: u32 = 0; i < 16u; i++) {
|
||||
qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u);
|
||||
}
|
||||
|
||||
// Bytes 192-207: 16 bytes of scales (4 u32s)
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
|
||||
scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4);
|
||||
}
|
||||
|
||||
var dst_i = dst_base + offset * 256;
|
||||
@@ -323,12 +344,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef IQ2_XXS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
|
||||
let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
|
||||
let aux0_offset = block_byte_base + 2 + ib * 2;
|
||||
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
|
||||
let aux0 = load_u32_at(&src, aux0_offset);
|
||||
let aux1 = load_u32_at(&src, aux1_offset);
|
||||
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = get_byte(aux0, l) * 8;
|
||||
@@ -345,15 +368,19 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
#ifdef IQ2_XS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
|
||||
var scale_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
|
||||
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
|
||||
load_u32_at(&src, block_byte_base + 66),
|
||||
load_u32_at(&src, block_byte_base + 70)
|
||||
);
|
||||
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
|
||||
let db = array<f32, 2>(
|
||||
@@ -361,7 +388,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
d * (0.5 + f32(s >> 4)) * 0.25
|
||||
);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
|
||||
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
|
||||
let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF;
|
||||
let ig = (qs_val & 511) * 8;
|
||||
let is = qs_val >> 9;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
@@ -379,21 +407,23 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef IQ2_S
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
|
||||
var qs_vals : array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
|
||||
qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
|
||||
}
|
||||
var qh_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
|
||||
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
|
||||
);
|
||||
var scale_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
|
||||
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
|
||||
);
|
||||
|
||||
var qh_vals: array<u32, 2>;
|
||||
qh_vals[0] = load_u32_at(&src, block_byte_base + 66);
|
||||
qh_vals[1] = load_u32_at(&src, block_byte_base + 70);
|
||||
|
||||
var scale_vals: array<u32, 2>;
|
||||
scale_vals[0] = load_u32_at(&src, block_byte_base + 74);
|
||||
scale_vals[1] = load_u32_at(&src, block_byte_base + 78);
|
||||
|
||||
for (var ib: u32 = 0; ib < 8; ib ++) {
|
||||
let s = get_byte(scale_vals[ib / 4], ib % 4);
|
||||
let db = array<f32, 2>(
|
||||
@@ -419,16 +449,17 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef IQ3_XXS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 16; ib += 2) {
|
||||
let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
|
||||
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
|
||||
let sc_sign = load_u32_at(&src, sc_sign_offset);
|
||||
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let is = (sc_sign >> (7 * l)) & 127;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
|
||||
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0);
|
||||
let ig2 = get_byte(ig_val, 1);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
@@ -448,18 +479,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef IQ3_S
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
|
||||
var qh_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
|
||||
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
|
||||
load_u32_at(&src, block_byte_base + 66),
|
||||
load_u32_at(&src, block_byte_base + 70)
|
||||
);
|
||||
|
||||
var sign_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
|
||||
sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4);
|
||||
}
|
||||
var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
|
||||
|
||||
var scale_vals = load_u32_at(&src, block_byte_base + 106);
|
||||
|
||||
for (var ib: u32 = 0; ib < 4; ib++) {
|
||||
let s = get_byte(scale_vals, ib);
|
||||
let db = array<f32, 2>(
|
||||
@@ -472,7 +507,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let sign_w = sign_vals[ib * 2 + k];
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let signs = get_byte(sign_w, l);
|
||||
let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
|
||||
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
|
||||
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
@@ -493,14 +528,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef IQ1_S
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
|
||||
let dl = d * (2 * f32((qh >> 12) & 7) + 1);
|
||||
let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF;
|
||||
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
|
||||
let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
|
||||
let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
@@ -560,12 +595,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef IQ4_NL
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 32;
|
||||
var qs: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
|
||||
qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
|
||||
}
|
||||
for (var j: u32 = 0; j < 16; j++) {
|
||||
let qsb = get_byte(qs[j / 4], j % 4);
|
||||
@@ -579,8 +614,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
#ifdef IQ4_XS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
|
||||
let d = unpack2x16float(block.d_scales_h)[0];
|
||||
let scales_h = block.d_scales_h >> 16;
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);
|
||||
|
||||
@@ -20,11 +20,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef Q4_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q4_0 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q4_0.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
|
||||
let q_byte_offset = block_byte_base + 2 + j * 4;
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
|
||||
@@ -61,12 +62,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef Q5_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q5_0 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q5_0.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 2);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
|
||||
let q_byte_offset = block_byte_base + 6 + j * 4;
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
|
||||
@@ -107,12 +109,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef Q8_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q8_0 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q8_0.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte_offset = block_byte_base + 2 + j * 4;
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
@@ -178,31 +181,37 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
#ifdef Q3_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
|
||||
|
||||
// Bytes 108-109: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base + 108);
|
||||
|
||||
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
|
||||
// and 2-bits from the last 4 bytes
|
||||
// Bytes 96-107: 12 bytes of scales (3 u32s)
|
||||
let kmask1: u32 = 0x03030303;
|
||||
let kmask2: u32 = 0x0f0f0f0f;
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
|
||||
}
|
||||
scale_vals[0] = load_u32_at(&src0, block_byte_base + 96);
|
||||
scale_vals[1] = load_u32_at(&src0, block_byte_base + 100);
|
||||
scale_vals[2] = load_u32_at(&src0, block_byte_base + 104);
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
|
||||
scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
|
||||
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
|
||||
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
|
||||
|
||||
// convert arrays of f16 -> u32
|
||||
// Bytes 0-31: 32 bytes of hmask (8 u32s)
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
|
||||
hmask_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 32-95: 64 bytes of qs (16 u32s)
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
|
||||
for (var i: u32 = 0u; i < 16; i++) {
|
||||
qs_vals[i] = load_u32_at(&src0, block_byte_base + 32 + i * 4);
|
||||
}
|
||||
|
||||
var sum = 0.0;
|
||||
@@ -301,21 +310,27 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
#ifdef Q6_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes
|
||||
|
||||
// convert arrays of f16 -> u32
|
||||
// Bytes 208-209: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base + 208);
|
||||
|
||||
// Bytes 0-127: 128 bytes of ql (32 u32s)
|
||||
var ql_vals: array<u32, 32>;
|
||||
for (var i: u32 = 0; i < 32; i++) {
|
||||
ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
|
||||
ql_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 128-191: 64 bytes of qh (16 u32s)
|
||||
var qh_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
|
||||
qh_vals[i] = load_u32_at(&src0, block_byte_base + 128 + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 192-207: 16 bytes of scales (4 u32s)
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
|
||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 192 + i * 4);
|
||||
}
|
||||
|
||||
var sum = 0.0;
|
||||
@@ -358,13 +373,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef IQ2_XXS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
|
||||
let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
|
||||
let aux0_offset = block_byte_base + 2 + ib * 2;
|
||||
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
|
||||
let aux0 = load_u32_at(&src0, aux0_offset);
|
||||
let aux1 = load_u32_at(&src0, aux1_offset);
|
||||
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = get_byte(aux0, l) * 8;
|
||||
@@ -384,13 +401,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef IQ2_XS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var scale_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
|
||||
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
|
||||
load_u32_at(&src0, block_byte_base + 66),
|
||||
load_u32_at(&src0, block_byte_base + 70)
|
||||
);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
|
||||
@@ -399,7 +418,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
d * (0.5 + f32(s >> 4)) * 0.25
|
||||
);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
|
||||
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
|
||||
let qs_val = load_u32_at(&src0, qs_offset) & 0xFFFF;
|
||||
let ig = (qs_val & 511) * 8;
|
||||
let is = qs_val >> 9;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
@@ -418,21 +438,23 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef IQ2_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var qs_vals : array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
|
||||
qs_vals[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
|
||||
}
|
||||
var qh_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
|
||||
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
|
||||
);
|
||||
var scale_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
|
||||
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
|
||||
);
|
||||
|
||||
var qh_vals: array<u32, 2>;
|
||||
qh_vals[0] = load_u32_at(&src0, block_byte_base + 66);
|
||||
qh_vals[1] = load_u32_at(&src0, block_byte_base + 70);
|
||||
|
||||
var scale_vals: array<u32, 2>;
|
||||
scale_vals[0] = load_u32_at(&src0, block_byte_base + 74);
|
||||
scale_vals[1] = load_u32_at(&src0, block_byte_base + 78);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib ++) {
|
||||
let s = get_byte(scale_vals[ib / 4], ib % 4);
|
||||
@@ -460,17 +482,18 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef IQ3_XXS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 16; ib += 2) {
|
||||
let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
|
||||
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
|
||||
let sc_sign = load_u32_at(&src0, sc_sign_offset);
|
||||
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let is = (sc_sign >> (7 * l)) & 127;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
|
||||
let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0);
|
||||
let ig2 = get_byte(ig_val, 1);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
@@ -491,18 +514,22 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef IQ3_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var qh_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
|
||||
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
|
||||
load_u32_at(&src0, block_byte_base + 66),
|
||||
load_u32_at(&src0, block_byte_base + 70)
|
||||
);
|
||||
|
||||
var sign_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
|
||||
sign_vals[i] = load_u32_at(&src0, block_byte_base + 74 + i * 4);
|
||||
}
|
||||
var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
|
||||
|
||||
var scale_vals = load_u32_at(&src0, block_byte_base + 106);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 4; ib++) {
|
||||
let s = get_byte(scale_vals, ib);
|
||||
@@ -516,7 +543,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let sign_w = sign_vals[ib * 2 + k];
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let signs = get_byte(sign_w, l);
|
||||
let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
|
||||
let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
|
||||
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
@@ -538,15 +565,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef IQ1_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
|
||||
let dl = d * (2 * f32((qh >> 12) & 7) + 1);
|
||||
let qh = load_u32_at(&src0, block_byte_base + 34 + ib * 2) & 0xFFFF;
|
||||
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
|
||||
let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
|
||||
let qs_w = load_u32_at(&src0, block_byte_base + 2 + ib * 4);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
@@ -610,13 +637,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef IQ4_NL
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 32;
|
||||
var sum = 0.0;
|
||||
var qs: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
|
||||
qs[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
|
||||
}
|
||||
for (var j: u32 = 0; j < 16; j++) {
|
||||
let qsb = get_byte(qs[j / 4], j % 4);
|
||||
@@ -631,8 +658,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
#ifdef IQ4_XS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
|
||||
let d = unpack2x16float(block.d_scales_h)[0];
|
||||
let scales_h = block.d_scales_h >> 16;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
|
||||
@@ -84,11 +84,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
@@ -125,12 +125,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_lo = f16(q_byte & 0xF) * d + m;
|
||||
@@ -171,12 +171,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
@@ -225,14 +225,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
|
||||
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
@@ -277,11 +277,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
@@ -317,12 +317,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
@@ -359,8 +359,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base + 80u);
|
||||
let dmin = load_src0_f16_at(block_byte_base + 82u);
|
||||
let d = load_f16_at(&src0, block_byte_base + 80u);
|
||||
let dmin = load_f16_at(&src0, block_byte_base + 82u);
|
||||
|
||||
// Decode the element at position k_in_block
|
||||
let block_of_32 = k_in_block / 32u;
|
||||
@@ -373,14 +373,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
|
||||
let is = k_in_block / 16u;
|
||||
|
||||
let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u));
|
||||
let sc_packed = load_u32_at(&src0, block_byte_base + 4u * (is / 4u));
|
||||
let sc = get_byte(sc_packed, is % 4u);
|
||||
|
||||
let dl = d * f16(sc & 0xFu);
|
||||
let ml = dmin * f16(sc >> 4u);
|
||||
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 3u;
|
||||
|
||||
@@ -413,7 +413,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base + 108u);
|
||||
let d = load_f16_at(&src0, block_byte_base + 108u);
|
||||
|
||||
// Load and unpack scales
|
||||
let kmask1: u32 = 0x03030303u;
|
||||
@@ -421,7 +421,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0u; i < 4u; i++) {
|
||||
scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i);
|
||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 96u + 4u * i);
|
||||
}
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
@@ -433,12 +433,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
// Load hmask and qs arrays
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0u; i < 8u; i++) {
|
||||
hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i);
|
||||
hmask_vals[i] = load_u32_at(&src0, block_byte_base + 4u * i);
|
||||
}
|
||||
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0u; i < 16u; i++) {
|
||||
qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i);
|
||||
qs_vals[i] = load_u32_at(&src0, block_byte_base + 32u + 4u * i);
|
||||
}
|
||||
|
||||
let half = k_in_block / 128u; // 0 or 1
|
||||
@@ -499,13 +499,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let dmin = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let dmin = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
|
||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
|
||||
}
|
||||
|
||||
// Map k_in_block to loop structure:
|
||||
@@ -541,7 +541,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 0xFu;
|
||||
@@ -575,13 +575,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let dmin = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let dmin = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
|
||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
|
||||
}
|
||||
|
||||
// The original loop processes elements in groups of 64
|
||||
@@ -621,11 +621,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u));
|
||||
let q_packed = load_u32_at(&src0, block_byte_base + 48u + 4u * (q_idx / 4u));
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u));
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (l / 4u));
|
||||
|
||||
let qh_byte = get_byte(qh_packed, l % 4u);
|
||||
|
||||
@@ -673,17 +673,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
|
||||
// Load only ql13 word needed
|
||||
let ql13_flat = ql_b_idx + l;
|
||||
let ql13 = load_src0_u32_at(block_byte_base + ql13_flat);
|
||||
let ql13 = load_u32_at(&src0, block_byte_base + ql13_flat);
|
||||
let ql13_b = get_byte(ql13, 0u);
|
||||
|
||||
// Load only ql24 word needed
|
||||
let ql24_flat = ql_b_idx + l + 32u;
|
||||
let ql24 = load_src0_u32_at(block_byte_base + ql24_flat);
|
||||
let ql24 = load_u32_at(&src0, block_byte_base + ql24_flat);
|
||||
let ql24_b = get_byte(ql24, 0u);
|
||||
|
||||
// Load only qh word needed
|
||||
let qh_flat = qh_b_idx + l;
|
||||
let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat);
|
||||
let qh = load_u32_at(&src0, block_byte_base + 128u + qh_flat);
|
||||
let qh_b = get_byte(qh, 0u);
|
||||
|
||||
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
|
||||
@@ -694,10 +694,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
// Load only the scale word needed
|
||||
let is = l / 16u;
|
||||
let sc_idx = sc_b_idx + is + quarter * 2u;
|
||||
let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx);
|
||||
let sc = load_u32_at(&src0, block_byte_base + 192u + sc_idx);
|
||||
let sc_val = get_byte_i32(sc, 0u);
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base + 208u);
|
||||
let d = load_f16_at(&src0, block_byte_base + 208u);
|
||||
|
||||
var q_val: f16;
|
||||
if (quarter == 0u) {
|
||||
|
||||
@@ -65,10 +65,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
@@ -98,11 +98,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let m = f32(load_src0_f16_at(block_byte_base + 2u));
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
let m = f32(load_f16_at(&src0, block_byte_base + 2u));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
|
||||
@@ -132,12 +132,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
@@ -176,13 +176,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
@@ -221,11 +221,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
@@ -254,12 +254,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d + f32(m);
|
||||
@@ -309,13 +309,13 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
for (var i = ix; i < nb; i += 2u) {
|
||||
let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = f32(load_src0_f16_at(bbase + 208u));
|
||||
let d = f32(load_f16_at(&src0, bbase + 208u));
|
||||
|
||||
let ql1_u32 = load_src0_u32_at(bbase + q_offset_l);
|
||||
let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u);
|
||||
let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h);
|
||||
let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte);
|
||||
let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u);
|
||||
let ql1_u32 = load_u32_at(&src0, bbase + q_offset_l);
|
||||
let ql2_u32 = load_u32_at(&src0, bbase + q_offset_l + 32u);
|
||||
let qh_u32 = load_u32_at(&src0, bbase + 128u + q_offset_h);
|
||||
let sc_u32_0 = load_u32_at(&src0, bbase + sc_base_byte);
|
||||
let sc_u32_1 = load_u32_at(&src0, bbase + sc_base_byte + 4u);
|
||||
|
||||
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
|
||||
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
|
||||
|
||||
@@ -107,7 +107,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx]));
|
||||
#endif
|
||||
#ifdef EXP
|
||||
let res = exp(src[params.offset_src + src_idx]);
|
||||
let src_f32 = f32(src[params.offset_src + src_idx]);
|
||||
let res = TYPE(exp(src_f32));
|
||||
#endif
|
||||
#ifdef LOG
|
||||
let res = TYPE(log(f32(src[params.offset_src + src_idx])));
|
||||
@@ -161,7 +162,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0));
|
||||
#endif
|
||||
#ifdef EXPM1
|
||||
let res = exp(src[params.offset_src + src_idx]) - 1.0;
|
||||
let src_f32 = f32(src[params.offset_src + src_idx]);
|
||||
let res = TYPE(exp(src_f32) - 1.0);
|
||||
#endif
|
||||
#ifdef FLOOR
|
||||
let res = floor(src[params.offset_src + src_idx]);
|
||||
@@ -181,7 +183,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx];
|
||||
#endif
|
||||
#ifdef SQRT
|
||||
let res = sqrt(src[params.offset_src + src_idx]);
|
||||
let res = TYPE(sqrt(f32(src[params.offset_src + src_idx])));
|
||||
#endif
|
||||
#ifdef SIN
|
||||
let res_f32 = sin(f32(src[params.offset_src + src_idx]));
|
||||
|
||||
@@ -152,14 +152,14 @@
|
||||
|
||||
{%- set ns = namespace(prev_message_type=None, last_user_message=-1) -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{{ bos_token }}
|
||||
{{- bos_token -}}
|
||||
{#- Handle System/Tool Definitions Block -#}
|
||||
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- '<|turn>system\n' -}}
|
||||
|
||||
{#- Inject Thinking token at the very top of the FIRST system turn -#}
|
||||
{%- if enable_thinking is defined and enable_thinking -%}
|
||||
{{- '<|think|>' -}}
|
||||
{{- '<|think|>\n' -}}
|
||||
{%- set ns.prev_message_type = 'think' -%}
|
||||
{%- endif -%}
|
||||
|
||||
@@ -255,13 +255,13 @@
|
||||
{{- item['text'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif item['type'] == 'image' -%}
|
||||
{{- '\n\n<|image|>\n\n' -}}
|
||||
{{- '<|image|>' -}}
|
||||
{%- set ns.prev_message_type = 'image' -%}
|
||||
{%- elif item['type'] == 'audio' -%}
|
||||
{{- '<|audio|>' -}}
|
||||
{%- set ns.prev_message_type = 'audio' -%}
|
||||
{%- elif item['type'] == 'video' -%}
|
||||
{{- '\n\n<|video|>\n\n' -}}
|
||||
{{- '<|video|>' -}}
|
||||
{%- set ns.prev_message_type = 'video' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
@@ -11,34 +11,15 @@
|
||||
description:<|"|>{{ value['description'] }}<|"|>
|
||||
{%- set add_comma = true -%}
|
||||
{%- endif -%}
|
||||
{%- if value['nullable'] %}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
nullable:true
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'STRING' -%}
|
||||
{%- if value['enum'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
enum:{{ format_argument(value['enum']) }}
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'OBJECT' -%}
|
||||
,properties:{
|
||||
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||
{%- elif value is mapping -%}
|
||||
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- if value['required'] -%}
|
||||
,required:[
|
||||
{%- for item in value['required'] | default([]) -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'ARRAY' -%}
|
||||
{%- if value['items'] is mapping and value['items'] -%}
|
||||
,items:{
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
items:{
|
||||
{%- set ns_items = namespace(found_first=false) -%}
|
||||
{%- for item_key, item_value in value['items'] | dictsort -%}
|
||||
{%- if item_value is not none -%}
|
||||
@@ -71,6 +52,32 @@
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if value['nullable'] %}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
nullable:true
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'OBJECT' -%}
|
||||
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
properties:{
|
||||
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||
}
|
||||
{%- elif value is mapping -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
properties:{
|
||||
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- if value['required'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
required:[
|
||||
{%- for item in value['required'] | default([]) -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
type:<|"|>{{ value['type'] | upper }}<|"|>}
|
||||
{%- endif -%}
|
||||
@@ -150,16 +157,31 @@
|
||||
{{- ns.result | trim -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- macro format_tool_response_block(tool_name, response) -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- if response is mapping -%}
|
||||
{{- 'response:' + tool_name + '{' -}}
|
||||
{%- for key, value in response | dictsort -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- else -%}
|
||||
{{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}}
|
||||
{%- endif -%}
|
||||
{{- '<tool_response|>' -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- set ns = namespace(prev_message_type=None) -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{{ bos_token }}
|
||||
{{- bos_token -}}
|
||||
{#- Handle System/Tool Definitions Block -#}
|
||||
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- '<|turn>system\n' -}}
|
||||
|
||||
{#- Inject Thinking token at the very top of the FIRST system turn -#}
|
||||
{%- if enable_thinking is defined and enable_thinking -%}
|
||||
{{- '<|think|>' -}}
|
||||
{{- '<|think|>\n' -}}
|
||||
{%- set ns.prev_message_type = 'think' -%}
|
||||
{%- endif -%}
|
||||
|
||||
@@ -180,11 +202,41 @@
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif %}
|
||||
|
||||
{#- Pre-scan: find last user message index for reasoning guard -#}
|
||||
{%- set ns_turn = namespace(last_user_idx=-1) -%}
|
||||
{%- for i in range(loop_messages | length) -%}
|
||||
{%- if loop_messages[i]['role'] == 'user' -%}
|
||||
{%- set ns_turn.last_user_idx = i -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{#- Loop through messages -#}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if message['role'] != 'tool' -%}
|
||||
{%- set ns.prev_message_type = None -%}
|
||||
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
|
||||
{#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#}
|
||||
{%- set prev_nt = namespace(role=None, found=false) -%}
|
||||
{%- if loop.index0 > 0 -%}
|
||||
{%- for j in range(loop.index0 - 1, -1, -1) -%}
|
||||
{%- if not prev_nt.found -%}
|
||||
{%- if loop_messages[j]['role'] != 'tool' -%}
|
||||
{%- set prev_nt.role = loop_messages[j]['role'] -%}
|
||||
{%- set prev_nt.found = true -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%}
|
||||
{%- if not continue_same_model_turn -%}
|
||||
{{- '<|turn>' + role + '\n' }}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Render reasoning/reasoning_content as thinking channel -#}
|
||||
{%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%}
|
||||
{%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%}
|
||||
{{- '<|channel>thought\n' + thinking_text + '\n<channel|>' -}}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['tool_calls'] -%}
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
@@ -205,23 +257,49 @@
|
||||
{%- set ns.prev_message_type = 'tool_call' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['tool_responses'] -%}
|
||||
{#- Tool Response handling -#}
|
||||
{%- set ns_tr_out = namespace(flag=false) -%}
|
||||
{%- if message.get('tool_responses') -%}
|
||||
{#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#}
|
||||
{%- for tool_response in message['tool_responses'] -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- if tool_response['response'] is mapping -%}
|
||||
{{- 'response:' + tool_response['name'] | default('unknown') + '{' -}}
|
||||
{%- for key, value in tool_response['response'] | dictsort -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- else -%}
|
||||
{{- 'response:' + tool_response['name'] | default('unknown') + '{value:' + format_argument(tool_response['response'], escape_keys=False) + '}' -}}
|
||||
{%- endif -%}
|
||||
{{- '<tool_response|>' -}}
|
||||
{{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}}
|
||||
{%- set ns_tr_out.flag = true -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endfor -%}
|
||||
{%- elif message.get('tool_calls') -%}
|
||||
{#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#}
|
||||
{%- set ns_tool_scan = namespace(stopped=false) -%}
|
||||
{%- for k in range(loop.index0 + 1, loop_messages | length) -%}
|
||||
{%- if ns_tool_scan.stopped -%}
|
||||
{%- elif loop_messages[k]['role'] != 'tool' -%}
|
||||
{%- set ns_tool_scan.stopped = true -%}
|
||||
{%- else -%}
|
||||
{%- set follow = loop_messages[k] -%}
|
||||
{#- Resolve tool_call_id to function name -#}
|
||||
{%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%}
|
||||
{%- for tc in message['tool_calls'] -%}
|
||||
{%- if tc.get('id') == follow.get('tool_call_id') -%}
|
||||
{%- set ns_tname.name = tc['function']['name'] -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{#- Handle content as string or content-parts array -#}
|
||||
{%- set tool_body = follow.get('content') -%}
|
||||
{%- if tool_body is string -%}
|
||||
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
|
||||
{%- elif tool_body is sequence and tool_body is not string -%}
|
||||
{%- set ns_txt = namespace(s='') -%}
|
||||
{%- for part in tool_body -%}
|
||||
{%- if part.get('type') == 'text' -%}
|
||||
{%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- format_tool_response_block(ns_tname.name, ns_txt.s) -}}
|
||||
{%- else -%}
|
||||
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
|
||||
{%- endif -%}
|
||||
{%- set ns_tr_out.flag = true -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['content'] is string -%}
|
||||
@@ -239,28 +317,31 @@
|
||||
{{- item['text'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif item['type'] == 'image' -%}
|
||||
{{- '\n\n<|image|>\n\n' -}}
|
||||
{{- '<|image|>' -}}
|
||||
{%- set ns.prev_message_type = 'image' -%}
|
||||
{%- elif item['type'] == 'audio' -%}
|
||||
{{- '<|audio|>' -}}
|
||||
{%- set ns.prev_message_type = 'audio' -%}
|
||||
{%- elif item['type'] == 'video' -%}
|
||||
{{- '\n\n<|video|>\n\n' -}}
|
||||
{{- '<|video|>' -}}
|
||||
{%- set ns.prev_message_type = 'video' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if not (message['tool_responses'] and not message['content']) -%}
|
||||
{%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- elif not (ns_tr_out.flag and not message.get('content')) -%}
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%}
|
||||
{{- '<|turn>model\n' -}}
|
||||
{%- endif -%}
|
||||
{%- if not enable_thinking | default(false) -%}
|
||||
{{- '<|channel>thought\n<channel|>' -}}
|
||||
{%- if not enable_thinking | default(false) -%}
|
||||
{{- '<|channel>thought\n<channel|>' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
|
||||
@@ -22,9 +22,6 @@ device="HTP0"
|
||||
verbose=
|
||||
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v"
|
||||
|
||||
experimental=
|
||||
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
|
||||
|
||||
profile=
|
||||
[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v"
|
||||
|
||||
@@ -46,7 +43,7 @@ adb $adbserial $adbhost shell " \
|
||||
cd $basedir; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$ndev $nhvx $opmask $verbose $experimental $profile $hb ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
|
||||
$ndev $nhvx $opmask $verbose $profile $hb ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ubatch-size 256 -fa 1 -ngl 99 $cli_opts $@ \
|
||||
"
|
||||
|
||||
@@ -21,9 +21,6 @@ model="Llama-3.2-3B-Instruct-Q4_0.gguf"
|
||||
device="HTP0"
|
||||
[ "$D" != "" ] && device="$D"
|
||||
|
||||
experimental=
|
||||
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
|
||||
|
||||
verbose=
|
||||
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v"
|
||||
|
||||
@@ -48,13 +45,22 @@ ndev=
|
||||
hb=
|
||||
[ "$HB" != "" ] && hb="GGML_HEXAGON_HOSTBUF=$HB"
|
||||
|
||||
opbatch=
|
||||
[ "$OB" != "" ] && opbatch="GGML_HEXAGON_OPBATCH=$OB"
|
||||
|
||||
opqueue=
|
||||
[ "$OQ" != "" ] && opqueue="GGML_HEXAGON_OPQUEUE=$OQ"
|
||||
|
||||
opflt=
|
||||
[ "$OF" != "" ] && opflt="GGML_HEXAGON_OPFILTER=$OF"
|
||||
|
||||
set -x
|
||||
|
||||
adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $hmx $ndev $hb \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $opflt \
|
||||
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --ubatch-size 256 -fa on \
|
||||
|
||||
@@ -21,9 +21,6 @@ model="Llama-3.2-3B-Instruct-Q4_0.gguf"
|
||||
device="HTP0"
|
||||
[ "$D" != "" ] && device="$D"
|
||||
|
||||
experimental=
|
||||
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
|
||||
|
||||
verbose=
|
||||
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v"
|
||||
|
||||
@@ -48,13 +45,22 @@ ndev=
|
||||
hb=
|
||||
[ "$HB" != "" ] && hb="GGML_HEXAGON_HOSTBUF=$HB"
|
||||
|
||||
opbatch=
|
||||
[ "$OB" != "" ] && opbatch="GGML_HEXAGON_OPBATCH=$OB"
|
||||
|
||||
opqueue=
|
||||
[ "$OQ" != "" ] && opqueue="GGML_HEXAGON_OPQUEUE=$OQ"
|
||||
|
||||
opflt=
|
||||
[ "$OF" != "" ] && opflt="GGML_HEXAGON_OPFILTER=$OF"
|
||||
|
||||
set -x
|
||||
|
||||
adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $hmx $ndev $hb \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $opflt \
|
||||
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --ubatch-size 256 -fa on \
|
||||
|
||||
@@ -21,9 +21,6 @@ device="HTP0"
|
||||
verbose=
|
||||
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V"
|
||||
|
||||
experimental=
|
||||
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
|
||||
|
||||
sched=
|
||||
[ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v"
|
||||
|
||||
@@ -53,5 +50,5 @@ adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $hmx $ndev $hb ./$branch/bin/$tool $@ \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb ./$branch/bin/$tool $@ \
|
||||
"
|
||||
|
||||
@@ -20,10 +20,6 @@ if ($null -ne $env:V) {
|
||||
$env:GGML_HEXAGON_VERBOSE=$env:V
|
||||
}
|
||||
|
||||
if ($null -ne $env:E) {
|
||||
$env:GGML_HEXAGON_EXPERIMENTAL=$env:E
|
||||
}
|
||||
|
||||
if ($null -ne $env:PROF) {
|
||||
$env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1
|
||||
}
|
||||
|
||||
@@ -20,10 +20,6 @@ if ($null -ne $env:V) {
|
||||
$env:GGML_HEXAGON_VERBOSE=$env:V
|
||||
}
|
||||
|
||||
if ($null -ne $env:E) {
|
||||
$env:GGML_HEXAGON_EXPERIMENTAL=$env:E
|
||||
}
|
||||
|
||||
if ($null -ne $env:SCHED) {
|
||||
$env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v"
|
||||
}
|
||||
|
||||
@@ -20,10 +20,6 @@ if ($null -ne $env:V) {
|
||||
$env:GGML_HEXAGON_VERBOSE=$env:V
|
||||
}
|
||||
|
||||
if ($null -ne $env:E) {
|
||||
$env:GGML_HEXAGON_EXPERIMENTAL=$env:E
|
||||
}
|
||||
|
||||
if ($null -ne $env:SCHED) {
|
||||
$env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v"
|
||||
}
|
||||
|
||||
@@ -29,12 +29,6 @@ if ($null -ne $env:V) {
|
||||
$env:GGML_HEXAGON_VERBOSE=$env:V
|
||||
}
|
||||
|
||||
# Default experimental to 1
|
||||
$env:GGML_HEXAGON_EXPERIMENTAL=1
|
||||
if ($null -ne $env:E) {
|
||||
$env:GGML_HEXAGON_EXPERIMENTAL=$env:E
|
||||
}
|
||||
|
||||
if ($null -ne $env:SCHED) {
|
||||
$env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v"
|
||||
}
|
||||
|
||||
@@ -26,10 +26,6 @@ if ($null -ne $env:V) {
|
||||
$env:GGML_HEXAGON_VERBOSE=$env:V
|
||||
}
|
||||
|
||||
if ($null -ne $env:E) {
|
||||
$env:GGML_HEXAGON_EXPERIMENTAL=$env:E
|
||||
}
|
||||
|
||||
if ($null -ne $env:SCHED) {
|
||||
$env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v"
|
||||
}
|
||||
|
||||
@@ -4623,17 +4623,18 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k(i);
|
||||
const int64_t n_embd_k = hparams.n_embd_k_gqa(i);
|
||||
const int64_t n_embd_v = hparams.n_embd_v_gqa(i);
|
||||
const int kv_flags = hparams.has_kv(i) ? 0 : TENSOR_NOT_REQUIRED;
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
// note: use_alternative_attention (v_proj is optional, if it's not present, use k_proj)
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head * n_head}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, kv_flags);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v}, TENSOR_NOT_REQUIRED);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head * n_head, n_embd}, 0);
|
||||
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, kv_flags);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
@@ -1014,7 +1014,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
model.hf_file = params.hf_file[i];
|
||||
}
|
||||
|
||||
auto download_result = common_download_model(model, params.hf_token);
|
||||
common_download_opts opts;
|
||||
opts.bearer_token = params.hf_token;
|
||||
auto download_result = common_download_model(model, opts);
|
||||
if (download_result.model_path.empty()) {
|
||||
fprintf(stderr, "error: failed to download model from HuggingFace\n");
|
||||
exit(1);
|
||||
|
||||
@@ -98,6 +98,7 @@ static void unset_reserved_args(common_preset & preset, bool unset_model_args) {
|
||||
if (unset_model_args) {
|
||||
preset.unset_option("LLAMA_ARG_MODEL");
|
||||
preset.unset_option("LLAMA_ARG_MMPROJ");
|
||||
preset.unset_option("LLAMA_ARG_ALIAS");
|
||||
preset.unset_option("LLAMA_ARG_HF_REPO");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user