Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2026-04-30 16:47:31 +03:00)
Compare commits
20 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 0c6ee1cade | |
| | 2dd84169d1 | |
| | f454bd7eb8 | |
| | b760272f1a | |
| | dcad77cc3b | |
| | 98dc1418ea | |
| | 9725a313be | |
| | d1649047a3 | |
| | 9d34231bb8 | |
| | 8ea8fee966 | |
| | eddd7a13a5 | |
| | dd2914dc81 | |
| | 0adede866d | |
| | 361fe72acb | |
| | a702f39597 | |
| | 13d36cf891 | |
| | f65bc34c68 | |
| | 15fa3c493b | |
| | dc80c5252a | |
| | e583f3b4f5 | |
.github/pull_request_template.md (vendored, 2 lines changed)
@@ -6,7 +6,7 @@
 <!-- You can provide more details and link related discussions here. Delete this section if not applicable -->
-# Requirements
+## Requirements
 <!-- IMPORTANT: Please do NOT delete this section, otherwise your PR may be rejected -->
.github/workflows/build-and-test-snapdragon.yml (vendored, 33 lines changed)
@@ -49,28 +49,19 @@ jobs:
cp docs/backend/snapdragon/CMakeUserPresets.json .
cmake --preset arm64-android-snapdragon-release -B build
cmake --build build
cmake --install build --prefix pkg-adb/llama.cpp
cmake --install build --prefix pkg-snapdragon/llama.cpp
- name: Upload Llama.CPP Snapdragon Android Build Artifact
if: ${{ always() && steps.build_llama_cpp_snapdragon_android.outcome == 'success' }}
uses: actions/upload-artifact@v6
with:
name: llama-cpp-android-arm64-snapdragon
path: pkg-adb/llama.cpp
check-secret:
runs-on: ubuntu-latest
outputs:
has-key: ${{ steps.check.outputs.has-key }}
steps:
- id: check
run: echo "has-key=${{ secrets.QDC_API_KEY != '' }}" >> "$GITHUB_OUTPUT"
path: pkg-snapdragon/llama.cpp
test-snapdragon-qdc:
name: Test on QDC Android Device (${{ matrix.device }})
needs: [android-ndk-snapdragon, check-secret]
if: needs.check-secret.outputs.has-key == 'true'
runs-on: ubuntu-latest
needs: [android-ndk-snapdragon]
runs-on: ubuntu-slim
strategy:
fail-fast: false
matrix:
@@ -81,10 +72,10 @@ jobs:
uses: actions/checkout@v6
- name: Download build artifact
uses: actions/download-artifact@v4
uses: actions/download-artifact@v7
with:
name: llama-cpp-android-arm64-snapdragon
path: pkg-snapdragon/
path: pkg-snapdragon/llama.cpp
- name: Set up Python
uses: actions/setup-python@v5
@@ -92,13 +83,25 @@ jobs:
python-version: '3.x'
cache: pip
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y curl unzip
- name: Install QDC SDK wheel
run: |
curl -fSL -o qdc_sdk.zip https://softwarecenter.qualcomm.com/api/download/software/tools/Qualcomm_Device_Cloud_SDK/All/0.2.3/qualcomm_device_cloud_sdk-0.2.3.zip
unzip qdc_sdk.zip -d qdc_sdk
pip install qdc_sdk/qualcomm_device_cloud_sdk-0.2.3-py3-none-any.whl
- name: Check QDC API key
id: check_secret
env:
QDC_API_KEY: ${{ secrets.QDC_API_KEY }}
run: echo "has-qdc-key=${{ env.QDC_API_KEY != '' }}" >> "$GITHUB_OUTPUT"
- name: Run QDC tests (${{ matrix.device }})
if: steps.check_secret.outputs.has-qdc-key == 'true'
run: |
python scripts/snapdragon/qdc/run_qdc_jobs.py \
--test all \
.gitignore (vendored, 12 lines changed)
@@ -34,7 +34,6 @@
/.vscode/
/nppBackup
# Coverage
/gcovr-report/
@@ -74,6 +73,7 @@
!/models/templates
# Zig
/zig-out/
/zig-cache/
@@ -93,6 +93,7 @@
!/examples/sycl/*.sh
# Server Web UI temporary files
/tools/server/webui/node_modules
/tools/server/webui/dist
# we no longer use gz for index.html
@@ -106,9 +107,11 @@ __pycache__/
poetry.toml
# Nix
/result
# Test binaries
/tests/test-backend-ops
/tests/test-double-float
/tests/test-grad0
@@ -124,6 +127,7 @@ poetry.toml
/tests/test-tokenizer-1-spm
# Scripts
!/scripts/install-oneapi.bat
# Generated by scripts
@@ -132,18 +136,24 @@ poetry.toml
/wikitext-2-raw/
# Test models for lora adapters
/lora-tests
# Local scripts
/run-vim.sh
/run-chat.sh
/run-spec.sh
/.ccache/
# IDE
/*.code-workspace
/.windsurf/
# emscripten
a.out.*
# AGENTS
AGENTS.local.md
.pi/SYSTEM.md
.pi/gg/SYSTEM.md (new file, 33 lines)
@@ -0,0 +1,33 @@
You are a coding agent. Here are some very important rules that you must follow:

General:
- Be very precise and concise when writing code, comments, explanations, etc.
- PR and commit titles format: `<module> : <title>`. Look up recent ones for examples
- Don't try to build or run the code unless you are explicitly asked to do so

Coding:
- When in doubt, always refer to the CONTRIBUTING.md file of the project
- When referencing issues or PRs in comments, use the format:
  - C/C++ code: `// ref: <url>`
  - Other (CMake, etc.): `# ref: <url>`

Pull requests (PRs):
- New branch names are prefixed with "gg/"
- Before opening a pull request, ask the user to confirm the description
- When creating a pull request, look for the repository's PR template and follow it
- For the AI usage disclosure section, write "YES. llama.cpp + pi"
- Always create the pull requests in draft mode

Commits:
- On every commit that you make, include an "Assisted-by: llama.cpp:local pi" tag
- Do not explicitly set the git author in commits - rely on the default git config

Resources (read on demand):
- [CONTRIBUTING.md](CONTRIBUTING.md)
- [Build documentation](docs/build.md)
- [Server usage documentation](tools/server/README.md)
- [Server development documentation](tools/server/README-dev.md)
- [PEG parser](docs/development/parsing.md)
- [Auto parser](docs/autoparser.md)
- [Jinja engine](common/jinja/README.md)
- [PR template](.github/pull_request_template.md)
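For illustration, a commit message following the title and trailer rules above might look like this; the module name and message are invented, not taken from the repository:

server : fix typo in chat endpoint docs

Assisted-by: llama.cpp:local pi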
@@ -296,7 +296,7 @@ void analyze_reasoning::compare_reasoning_presence() {
return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())) + p.rest());
});
auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.tag("pre", p.marker() + p.space()) + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest();
return p.tag("pre", p.marker() + p.space()) + p.literal(reasoning_content) + p.tag("post", (p.space() + p.marker() + p.space())) + p.rest();
});
// try the more aggressive parse first, if it fails, fall back to the delimiter one
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
@@ -306,11 +306,11 @@ void analyze_reasoning::compare_reasoning_presence() {
if (result.result.success()) {
if (!result.tags["pre"].empty() && !result.tags["post"].empty()) {
mode = reasoning_mode::TAG_BASED;
start = trim_leading_whitespace(result.tags["pre"]);
end = trim_trailing_whitespace(result.tags["post"]);
start = result.tags["pre"];
end = result.tags["post"];
} else if (!result.tags["post"].empty()) {
mode = reasoning_mode::TAG_BASED;
end = trim_trailing_whitespace(result.tags["post"]);
end = result.tags["post"];
}
}
}
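To make the tagged-parser change above easier to follow, here is a minimal usage sketch. It relies only on the builder API as it appears in the hunk (build_tagged_peg_parser, p.tag, p.marker, p.space, p.literal, p.rest, parse_anywhere_and_extract); the sample strings and variable names are illustrative assumptions, not project code.

```cpp
// Hedged sketch: extract the "pre"/"post" reasoning markers from a model output
// such as "<think> some reasoning </think> final answer".
const std::string reasoning_content = "some reasoning";                       // known reasoning text to anchor on
const std::string output = "<think> some reasoning </think> final answer";    // illustrative model output

auto parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) {
    return p.tag("pre", p.marker() + p.space()) + p.literal(reasoning_content) +
           p.tag("post", (p.space() + p.marker() + p.space())) + p.rest();
});

auto res = parser.parse_anywhere_and_extract(output);
if (res.result.success()) {
    // res.tags["pre"]  would hold the opening marker (e.g. "<think> ")
    // res.tags["post"] would hold the closing marker (e.g. " </think> ")
}
```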
@@ -106,10 +106,16 @@ struct statement {
|
||||
size_t pos; // position in source, for debugging
|
||||
virtual ~statement() = default;
|
||||
virtual std::string type() const { return "Statement"; }
|
||||
|
||||
// execute_impl must be overridden by derived classes
|
||||
virtual value execute_impl(context &) { throw std::runtime_error("cannot exec " + type()); }
|
||||
virtual value execute_impl(context &) { throw_exec_error(); }
|
||||
// execute is the public method to execute a statement with error handling
|
||||
value execute(context &);
|
||||
|
||||
private:
|
||||
[[noreturn]] void throw_exec_error() const {
|
||||
throw std::runtime_error("cannot exec " + type());
|
||||
}
|
||||
};
|
||||
|
||||
// Type Checking Utilities
|
||||
@@ -143,7 +149,7 @@ struct program : public statement {
|
||||
program() = default;
|
||||
explicit program(statements && body) : body(std::move(body)) {}
|
||||
std::string type() const override { return "Program"; }
|
||||
value execute_impl(context &) override {
|
||||
[[noreturn]] value execute_impl(context &) override {
|
||||
throw std::runtime_error("Cannot execute program directly, use jinja::runtime instead");
|
||||
}
|
||||
};
|
||||
@@ -195,7 +201,7 @@ struct break_statement : public statement {
|
||||
}
|
||||
};
|
||||
|
||||
value execute_impl(context &) override {
|
||||
[[noreturn]] value execute_impl(context &) override {
|
||||
throw break_statement::signal();
|
||||
}
|
||||
};
|
||||
@@ -209,7 +215,7 @@ struct continue_statement : public statement {
|
||||
}
|
||||
};
|
||||
|
||||
value execute_impl(context &) override {
|
||||
[[noreturn]] value execute_impl(context &) override {
|
||||
throw continue_statement::signal();
|
||||
}
|
||||
};
|
||||
@@ -509,7 +515,7 @@ struct slice_expression : public expression {
|
||||
chk_type<expression>(this->step_expr);
|
||||
}
|
||||
std::string type() const override { return "SliceExpression"; }
|
||||
value execute_impl(context &) override {
|
||||
[[noreturn]] value execute_impl(context &) override {
|
||||
throw std::runtime_error("must be handled by MemberExpression");
|
||||
}
|
||||
};
|
||||
|
||||
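The hunks above repeatedly apply the same refactor: virtual methods that can only throw are marked [[noreturn]], and the actual throw is moved into a small private helper so the call sites stay compact. A minimal, self-contained sketch of that pattern follows; it is not the project's exact classes, just the shape of the change.

```cpp
#include <stdexcept>
#include <string>

struct statement {
    virtual ~statement() = default;
    virtual std::string type() const { return "Statement"; }
    // Derived classes override this; the default can only fail.
    virtual int execute_impl() { throw_exec_error(); }
private:
    [[noreturn]] void throw_exec_error() const {
        throw std::runtime_error("cannot exec " + type());
    }
};

struct program : statement {
    std::string type() const override { return "Program"; }
    // Marking the override [[noreturn]] documents that it never produces a value.
    [[noreturn]] int execute_impl() override {
        throw std::runtime_error("cannot execute a program directly");
    }
};
```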
@@ -590,6 +590,10 @@ static bool string_endswith(const std::string & str, const std::string & suffix)
|
||||
return str.compare(str.length() - suffix.length(), suffix.length(), suffix) == 0;
|
||||
}
|
||||
|
||||
[[noreturn]] static value string_join_not_implemented(const func_args &) {
|
||||
throw not_implemented_exception("String join builtin not implemented");
|
||||
}
|
||||
|
||||
const func_builtins & value_string_t::get_builtins() const {
|
||||
static const func_builtins builtins = {
|
||||
{"default", default_value},
|
||||
@@ -851,9 +855,7 @@ const func_builtins & value_string_t::get_builtins() const {
|
||||
res->val_str.mark_input_based_on(val_input->as_string());
|
||||
return res;
|
||||
}},
|
||||
{"join", [](const func_args &) -> value {
|
||||
throw not_implemented_exception("String join builtin not implemented");
|
||||
}},
|
||||
{"join", string_join_not_implemented},
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
@@ -884,6 +886,9 @@ const func_builtins & value_bool_t::get_builtins() const {
|
||||
return builtins;
|
||||
}
|
||||
|
||||
[[noreturn]] static value array_unique_not_implemented(const func_args &) {
|
||||
throw not_implemented_exception("Array unique builtin not implemented");
|
||||
}
|
||||
|
||||
const func_builtins & value_array_t::get_builtins() const {
|
||||
static const func_builtins builtins = {
|
||||
@@ -1084,13 +1089,14 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
std::reverse(arr.begin(), arr.end());
|
||||
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
|
||||
}},
|
||||
{"unique", [](const func_args &) -> value {
|
||||
throw not_implemented_exception("Array unique builtin not implemented");
|
||||
}},
|
||||
{"unique", array_unique_not_implemented},
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
|
||||
[[noreturn]] static value object_join_not_implemented(const func_args &) {
|
||||
throw not_implemented_exception("object join not implemented");
|
||||
}
|
||||
|
||||
const func_builtins & value_object_t::get_builtins() const {
|
||||
if (!has_builtins) {
|
||||
@@ -1183,9 +1189,7 @@ const func_builtins & value_object_t::get_builtins() const {
|
||||
});
|
||||
return result;
|
||||
}},
|
||||
{"join", [](const func_args &) -> value {
|
||||
throw not_implemented_exception("object join not implemented");
|
||||
}},
|
||||
{"join", object_join_not_implemented},
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
|
||||
@@ -129,27 +129,25 @@ struct value_t {
|
||||
// Note: only for debugging and error reporting purposes
|
||||
virtual std::string type() const { return ""; }
|
||||
|
||||
virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); }
|
||||
virtual double as_float() const { throw std::runtime_error(type() + " is not a float value"); }
|
||||
virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
|
||||
virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
|
||||
virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
|
||||
virtual const std::vector<std::pair<value, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
|
||||
virtual int64_t as_int() const { throw_type_error("is not an int value"); }
|
||||
virtual double as_float() const { throw_type_error("is not a float value"); }
|
||||
virtual string as_string() const { throw_type_error("is not a string value"); }
|
||||
virtual bool as_bool() const { throw_type_error("is not a bool value"); }
|
||||
virtual const std::vector<value> & as_array() const { throw_type_error("is not an array value"); }
|
||||
virtual const std::vector<std::pair<value, value>> & as_ordered_object() const { throw_type_error("is not an object value"); }
|
||||
virtual value invoke(const func_args &) const { throw_type_error("is not a function value"); }
|
||||
virtual bool is_none() const { return false; }
|
||||
virtual bool is_undefined() const { return false; }
|
||||
virtual const func_builtins & get_builtins() const {
|
||||
throw std::runtime_error("No builtins available for type " + type());
|
||||
}
|
||||
virtual const func_builtins & get_builtins() const { throw_type_error("has no builtins"); }
|
||||
|
||||
virtual bool has_key(const value &) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual void insert(const value & /* key */, const value & /* val */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(const value & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(const value & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(const std::string & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(const std::string & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(int64_t /* idx */, value & /* default_val */) { throw std::runtime_error(type() + " is not an array value"); }
|
||||
virtual value & at(int64_t /* idx */) { throw std::runtime_error(type() + " is not an array value"); }
|
||||
virtual bool has_key(const value &) { throw_type_error("is not an object value"); }
|
||||
virtual void insert(const value & /* key */, const value & /* val */) { throw_type_error("is not an object value"); }
|
||||
virtual value & at(const value & /* key */, value & /* default_val */) { throw_type_error("is not an object value"); }
|
||||
virtual value & at(const value & /* key */) { throw_type_error("is not an object value"); }
|
||||
virtual value & at(const std::string & /* key */, value & /* default_val */) { throw_type_error("is not an object value"); }
|
||||
virtual value & at(const std::string & /* key */) { throw_type_error("is not an object value"); }
|
||||
virtual value & at(int64_t /* idx */, value & /* default_val */) { throw_type_error("is not an array value"); }
|
||||
virtual value & at(int64_t /* idx */) { throw_type_error("is not an array value"); }
|
||||
|
||||
virtual bool is_numeric() const { return false; }
|
||||
virtual bool is_hashable() const { return false; }
|
||||
@@ -163,6 +161,11 @@ struct value_t {
|
||||
// Note: only for debugging purposes
|
||||
virtual std::string as_repr() const { return as_string().str(); }
|
||||
|
||||
private:
|
||||
[[noreturn]] void throw_type_error(const char* expected) const {
|
||||
throw std::runtime_error(type() + " " + expected);
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual bool equivalent(const value_t &) const = 0;
|
||||
virtual bool nonequal(const value_t & other) const { return !equivalent(other); }
|
||||
|
||||
@@ -61,18 +61,26 @@ static bool common_speculative_are_compatible(
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__);
LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
LOG_WRN("%s: draft model vocab type must match target model to use speculation but "
"vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
return false;
}
if (
llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
) {
LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__);
if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
(llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) {
LOG_WRN("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
__func__,
llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft),
llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft));
return false;
}
if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
(llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) {
LOG_WRN("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
__func__,
llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft),
llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft));
return false;
}
@@ -51,6 +51,12 @@ The packages for FP32 and FP16 would have different accuracy and performance on

## News

- 2026.04

  - Optimize mul_mat via the reorder feature for the Q4_K, Q5_K, Q_K and Q8_0 data types.
  - Fused MoE.
  - Upgrade CI and the built package to oneAPI 2025.3.3; support an Ubuntu 24.04 built package.

- 2026.03
  - Support Flash-Attention: less memory usage; the performance impact depends on the LLM.

@@ -349,6 +355,12 @@ Choose one of following methods to run.
  ./examples/sycl/test.sh
  ```

- Run llama-server:

  ```sh
  ./examples/sycl/start-svr.sh -m PATH/MODEL_FILE
  ```

2. Command line

Launch inference

@@ -637,10 +649,18 @@ Choose one of following methods to run.

1. Script

- Run test:

  ```
  examples\sycl\win-test.bat
  ```

- Run llama-server:

  ```
  examples\sycl\win-start-svr.bat -m PATH\MODEL_FILE
  ```

2. Command line

Launch inference
@@ -26,7 +26,7 @@ Legend:
| CLAMP | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
@@ -60,7 +60,7 @@ Legend:
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
@@ -105,7 +105,7 @@ Legend:
| SQR | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
docs/ops/WebGPU.csv (5501 lines changed; file diff suppressed because it is too large)
examples/sycl/start-svr.sh (new executable file, 124 lines)
@@ -0,0 +1,124 @@
|
||||
#!/bin/bash
|
||||
|
||||
# MIT license
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
Help() {
|
||||
cat << EOF
|
||||
Usage: $(basename "$0") [OPTIONS]
|
||||
|
||||
This script processes files with specified options.
|
||||
|
||||
Options:
|
||||
-h, --help Display this help message and exit.
|
||||
-c, --context <value> Set context length. Bigger need more memory.
|
||||
-p, --promote <value> Prompt to start generation with.
|
||||
-m, --model <value> Full model file path.
|
||||
-mg,--main-gpu <value> Set main GPU ID (0 - n) for single GPU mode.
|
||||
-sm,--split-mode <value> How to split the model across multiple GPUs, one of:
|
||||
- none: use one GPU only
|
||||
- layer (default): split layers and KV across GPUs
|
||||
- row: split rows across GPUs
|
||||
-ngl,--n-gpu-layers <value> Max. number of layers to store in VRAM (default: -1)
|
||||
-lv,--log-verbosity <value> Set the verbosity threshold. Messages with a higher verbosity will be
|
||||
ignored. Values:
|
||||
- 0: generic output
|
||||
- 1: error
|
||||
- 2: warning
|
||||
- 3: info
|
||||
- 4: debug
|
||||
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
BIN_FILE=./build/bin/llama-server
|
||||
SEED=0
|
||||
GPUS_SETTING=""
|
||||
|
||||
MODEL_FILE=../models/Qwen3.5-4B-Q4_0.gguf
|
||||
NGL=99
|
||||
CONTEXT=4096
|
||||
GGML_SYCL_DEVICE=-1
|
||||
SPLIT_MODE=layer
|
||||
LOG_VERBOSE=3
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
-c|--context)
|
||||
CONTEXT=$2
|
||||
# Shift twice to consume both the option flag and its value
|
||||
shift
|
||||
shift
|
||||
;;
|
||||
-m|--model)
|
||||
MODEL_FILE="$2"
|
||||
# Shift twice to consume both the option flag and its value
|
||||
shift
|
||||
shift
|
||||
;;
|
||||
-mg|--main-gpu)
|
||||
GGML_SYCL_DEVICE=$2
|
||||
SPLIT_MODE=none
|
||||
# Shift twice to consume both the option flag and its value
|
||||
shift
|
||||
shift
|
||||
;;
|
||||
-sm|--split-mode)
|
||||
SPLIT_MODE=$2
|
||||
# Shift twice to consume both the option flag and its value
|
||||
shift
|
||||
shift
|
||||
;;
|
||||
-ngl|--n-gpu-layers)
|
||||
NGL=$2
|
||||
# Shift twice to consume both the option flag and its value
|
||||
shift
|
||||
shift
|
||||
;;
|
||||
-lv|--log-verbosity)
|
||||
LOG_VERBOSE=$2
|
||||
# Shift twice to consume both the option flag and its value
|
||||
shift
|
||||
shift
|
||||
;;
|
||||
-h|--help)
|
||||
Help
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
# Handle unknown options or stop processing options
|
||||
echo "Invalid option: $1"
|
||||
# Optional: exit script or shift to treat remaining as positional args
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
|
||||
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
#export GGML_SYCL_DEBUG=1
|
||||
|
||||
#ZES_ENABLE_SYSMAN=1, Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory. Recommended to use when --split-mode = layer.
|
||||
|
||||
#support malloc device memory more than 4GB.
|
||||
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
echo "UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=${UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS}"
|
||||
|
||||
if [ $GGML_SYCL_DEVICE -ne -1 ]; then
|
||||
echo "Use $GGML_SYCL_DEVICE as main GPU"
|
||||
#use single GPU only
|
||||
GPUS_SETTING="-mg $GGML_SYCL_DEVICE -sm ${SPLIT_MODE}"
|
||||
export ONEAPI_DEVICE_SELECTOR="level_zero:${GGML_SYCL_DEVICE}"
|
||||
echo "ONEAPI_DEVICE_SELECTOR=${ONEAPI_DEVICE_SELECTOR}"
|
||||
else
|
||||
echo "Use all Intel GPUs, including iGPU & dGPU"
|
||||
GPUS_SETTING="-sm ${SPLIT_MODE}"
|
||||
fi
|
||||
|
||||
echo "run cmd: ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 200 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap "
|
||||
ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap --host 0.0.0.0 --port 8000
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ SEED=0
GPUS_SETTING=""
INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
MODEL_FILE=models/llama-2-7b.Q4_0.gguf
MODEL_FILE=../models/llama-2-7b.Q4_0.gguf
NGL=99
CONTEXT=4096
GGML_SYCL_DEVICE=-1
@@ -122,9 +122,10 @@ if [ $GGML_SYCL_DEVICE -ne -1 ]; then
export ONEAPI_DEVICE_SELECTOR="level_zero:${$GGML_SYCL_DEVICE}"
echo "ONEAPI_DEVICE_SELECTOR=${ONEAPI_DEVICE_SELECTOR}"
else
echo "Use all Intel GPUs, including iGPU & dGPU"
echo "Use all Intel GPUs, including iGPU & dGPU"
GPUS_SETTING="-sm ${SPLIT_MODE}"
fi
echo "run cmd: ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap "
ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap
echo "run cmd: ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 200 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap "
ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 200 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap
examples/sycl/win-start-svr.bat (new file, 179 lines)
@@ -0,0 +1,179 @@
|
||||
:: MIT license
|
||||
:: Copyright (C) 2024 Intel Corporation
|
||||
:: SPDX-License-Identifier: MIT
|
||||
|
||||
@echo off
|
||||
setlocal EnableExtensions EnableDelayedExpansion
|
||||
|
||||
set "BIN_FILE=.\build\bin\llama-server.exe"
|
||||
set "SEED=0"
|
||||
set "GPUS_SETTING="
|
||||
|
||||
set "MODEL_FILE=..\models\Qwen3.5-4B-Q4_0.gguf"
|
||||
set "NGL=99"
|
||||
set "CONTEXT=4096"
|
||||
set "GGML_SYCL_DEVICE=-1"
|
||||
set "SPLIT_MODE=layer"
|
||||
set "LOG_VERBOSE=3"
|
||||
|
||||
if "%~1"=="" goto after_args
|
||||
|
||||
:parse_args
|
||||
if "%~1"=="" goto after_args
|
||||
|
||||
if /I "%~1"=="-c" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "CONTEXT=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
if /I "%~1"=="--context" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "CONTEXT=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
|
||||
if /I "%~1"=="-m" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "MODEL_FILE=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
if /I "%~1"=="--model" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "MODEL_FILE=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
|
||||
if /I "%~1"=="-mg" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "GGML_SYCL_DEVICE=%~2"
|
||||
set "SPLIT_MODE=none"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
if /I "%~1"=="--main-gpu" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "GGML_SYCL_DEVICE=%~2"
|
||||
set "SPLIT_MODE=none"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
|
||||
if /I "%~1"=="-sm" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "SPLIT_MODE=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
if /I "%~1"=="--split-mode" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "SPLIT_MODE=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
|
||||
if /I "%~1"=="-ngl" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "NGL=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
if /I "%~1"=="--n-gpu-layers" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "NGL=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
|
||||
if /I "%~1"=="-lv" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "LOG_VERBOSE=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
if /I "%~1"=="--log-verbosity" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "LOG_VERBOSE=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
|
||||
if /I "%~1"=="-h" goto help
|
||||
if /I "%~1"=="--help" goto help
|
||||
|
||||
echo Invalid option: %~1
|
||||
exit /b 1
|
||||
|
||||
:missing_value
|
||||
echo Missing value for option: %~1
|
||||
exit /b 1
|
||||
|
||||
:help
|
||||
echo Usage: %~n0 [OPTIONS]
|
||||
echo.
|
||||
echo This script processes files with specified options.
|
||||
echo.
|
||||
echo Options:
|
||||
echo -h, --help Display this help message and exit.
|
||||
echo -c, --context ^<value^> Set context length. Bigger need more memory.
|
||||
echo -m, --model ^<value^> Full model file path.
|
||||
echo -mg,--main-gpu ^<value^> Set main GPU ID (0 - n) for single GPU mode.
|
||||
echo -sm,--split-mode ^<value^> How to split the model across multiple GPUs, one of:
|
||||
echo - none: use one GPU only
|
||||
echo - layer (default): split layers and KV across GPUs
|
||||
echo - row: split rows across GPUs
|
||||
echo -ngl,--n-gpu-layers ^<value^> Max. number of layers to store in VRAM (default: -1)
|
||||
echo -lv,--log-verbosity ^<value^> Set the verbosity threshold. Messages with a higher verbosity will be
|
||||
echo ignored. Values:
|
||||
echo - 0: generic output
|
||||
echo - 1: error
|
||||
echo - 2: warning
|
||||
echo - 3: info
|
||||
echo - 4: debug
|
||||
exit /b 0
|
||||
|
||||
:after_args
|
||||
|
||||
REM In Windows CMD, source is not available; call oneAPI setvars if present.
|
||||
if exist "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" (
|
||||
call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" >nul
|
||||
) else (
|
||||
echo Warning: oneAPI setvars.bat not found. Continuing without environment setup.
|
||||
)
|
||||
|
||||
REM Support malloc device memory more than 4GB.
|
||||
set "UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1"
|
||||
echo UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=%UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS%
|
||||
|
||||
if not "%GGML_SYCL_DEVICE%"=="-1" (
|
||||
echo Use %GGML_SYCL_DEVICE% as main GPU
|
||||
REM Use single GPU only.
|
||||
set "GPUS_SETTING=-mg %GGML_SYCL_DEVICE% -sm %SPLIT_MODE%"
|
||||
set "ONEAPI_DEVICE_SELECTOR=level_zero:%GGML_SYCL_DEVICE%"
|
||||
echo ONEAPI_DEVICE_SELECTOR=%ONEAPI_DEVICE_SELECTOR%
|
||||
) else (
|
||||
echo Use all Intel GPUs, including iGPU ^& dGPU
|
||||
set "GPUS_SETTING=-sm %SPLIT_MODE%"
|
||||
)
|
||||
|
||||
echo run cmd: ZES_ENABLE_SYSMAN=1 %BIN_FILE% -m "%MODEL_FILE%" -ngl %NGL% -s %SEED% -c %CONTEXT% %GPUS_SETTING% -lv %LOG_VERBOSE% --mmap --host 0.0.0.0 --port 8000
|
||||
set "ZES_ENABLE_SYSMAN=1"
|
||||
%BIN_FILE% -m "%MODEL_FILE%" -ngl %NGL% -s %SEED% -c %CONTEXT% %GPUS_SETTING% -lv %LOG_VERBOSE% --mmap --host 0.0.0.0 --port 8000
|
||||
|
||||
endlocal
|
||||
|
||||
@@ -2,10 +2,200 @@
|
||||
:: Copyright (C) 2024 Intel Corporation
|
||||
:: SPDX-License-Identifier: MIT
|
||||
|
||||
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
|
||||
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
|
||||
|
||||
:: support malloc device memory more than 4GB.
|
||||
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
set LOAD_MODE="--mmap"
|
||||
.\build\bin\llama-completion.exe -m models\llama-2-7b.Q4_0.gguf -no-cnv -p %INPUT2% -n 400 -e -ngl 99 -s 0 %LOAD_MODE%
|
||||
@echo off
|
||||
setlocal EnableExtensions EnableDelayedExpansion
|
||||
|
||||
REM MIT license
|
||||
REM Copyright (C) 2024 Intel Corporation
|
||||
REM SPDX-License-Identifier: MIT
|
||||
|
||||
set "BIN_FILE=.\build\bin\llama-completion.exe"
|
||||
set "SEED=0"
|
||||
set "GPUS_SETTING="
|
||||
|
||||
set "INPUT_PROMPT=Building a website can be done in 10 simple steps:^nStep 1:"
|
||||
set "MODEL_FILE=..\models\llama-2-7b.Q4_0.gguf"
|
||||
set "NGL=99"
|
||||
set "CONTEXT=4096"
|
||||
set "GGML_SYCL_DEVICE=-1"
|
||||
set "SPLIT_MODE=layer"
|
||||
set "LOG_VERBOSE=3"
|
||||
|
||||
if "%~1"=="" goto after_args
|
||||
|
||||
:parse_args
|
||||
if "%~1"=="" goto after_args
|
||||
|
||||
if /I "%~1"=="-c" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "CONTEXT=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
if /I "%~1"=="--context" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "CONTEXT=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
|
||||
if /I "%~1"=="-p" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "INPUT_PROMPT=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
if /I "%~1"=="--promote" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "INPUT_PROMPT=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
|
||||
if /I "%~1"=="-m" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "MODEL_FILE=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
if /I "%~1"=="--model" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "MODEL_FILE=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
|
||||
if /I "%~1"=="-mg" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "GGML_SYCL_DEVICE=%~2"
|
||||
set "SPLIT_MODE=none"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
if /I "%~1"=="--main-gpu" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "GGML_SYCL_DEVICE=%~2"
|
||||
set "SPLIT_MODE=none"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
|
||||
if /I "%~1"=="-sm" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "SPLIT_MODE=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
if /I "%~1"=="--split-mode" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "SPLIT_MODE=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
|
||||
if /I "%~1"=="-ngl" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "NGL=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
if /I "%~1"=="--n-gpu-layers" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "NGL=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
|
||||
if /I "%~1"=="-lv" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "LOG_VERBOSE=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
if /I "%~1"=="--log-verbosity" (
|
||||
if "%~2"=="" goto missing_value
|
||||
set "LOG_VERBOSE=%~2"
|
||||
shift
|
||||
shift
|
||||
goto parse_args
|
||||
)
|
||||
|
||||
if /I "%~1"=="-h" goto help
|
||||
if /I "%~1"=="--help" goto help
|
||||
|
||||
echo Invalid option: %~1
|
||||
exit /b 1
|
||||
|
||||
:missing_value
|
||||
echo Missing value for option: %~1
|
||||
exit /b 1
|
||||
|
||||
:help
|
||||
echo Usage: %~n0 [OPTIONS]
|
||||
echo.
|
||||
echo This script processes files with specified options.
|
||||
echo.
|
||||
echo Options:
|
||||
echo -h, --help Display this help message and exit.
|
||||
echo -c, --context ^<value^> Set context length. Bigger need more memory.
|
||||
echo -p, --promote ^<value^> Prompt to start generation with.
|
||||
echo -m, --model ^<value^> Full model file path.
|
||||
echo -mg,--main-gpu ^<value^> Set main GPU ID (0 - n) for single GPU mode.
|
||||
echo -sm,--split-mode ^<value^> How to split the model across multiple GPUs, one of:
|
||||
echo - none: use one GPU only
|
||||
echo - layer (default): split layers and KV across GPUs
|
||||
echo - row: split rows across GPUs
|
||||
echo -ngl,--n-gpu-layers ^<value^> Max. number of layers to store in VRAM (default: -1)
|
||||
echo -lv,--log-verbosity ^<value^> Set the verbosity threshold. Messages with a higher verbosity will be
|
||||
echo ignored. Values:
|
||||
echo - 0: generic output
|
||||
echo - 1: error
|
||||
echo - 2: warning
|
||||
echo - 3: info
|
||||
echo - 4: debug
|
||||
exit /b 0
|
||||
|
||||
:after_args
|
||||
|
||||
REM In Windows CMD, source is not available; call oneAPI setvars if present.
|
||||
if exist "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" (
|
||||
call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" >nul
|
||||
) else (
|
||||
echo Warning: oneAPI setvars.bat not found. Continuing without environment setup.
|
||||
)
|
||||
|
||||
REM Support malloc device memory more than 4GB.
|
||||
set "UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1"
|
||||
echo UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=%UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS%
|
||||
|
||||
if not "%GGML_SYCL_DEVICE%"=="-1" (
|
||||
echo Use %GGML_SYCL_DEVICE% as main GPU
|
||||
REM Use single GPU only.
|
||||
set "GPUS_SETTING=-mg %GGML_SYCL_DEVICE% -sm %SPLIT_MODE%"
|
||||
set "ONEAPI_DEVICE_SELECTOR=level_zero:%GGML_SYCL_DEVICE%"
|
||||
echo ONEAPI_DEVICE_SELECTOR=%ONEAPI_DEVICE_SELECTOR%
|
||||
) else (
|
||||
echo Use all Intel GPUs, including iGPU ^& dGPU
|
||||
set "GPUS_SETTING=-sm %SPLIT_MODE%"
|
||||
)
|
||||
|
||||
echo run cmd: ZES_ENABLE_SYSMAN=1 %BIN_FILE% -m %MODEL_FILE% -no-cnv -p "%INPUT_PROMPT%" -n 200 -e -ngl %NGL% -s %SEED% -c %CONTEXT% %GPUS_SETTING% -lv %LOG_VERBOSE% --mmap
|
||||
set "ZES_ENABLE_SYSMAN=1"
|
||||
%BIN_FILE% -m "%MODEL_FILE%" -no-cnv -p "%INPUT_PROMPT%" -n 200 -e -ngl %NGL% -s %SEED% -c %CONTEXT% %GPUS_SETTING% -lv %LOG_VERBOSE% --mmap
|
||||
|
||||
endlocal
|
||||
|
||||
|
||||
@@ -2300,9 +2300,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
|
||||
#if defined __AVX2__
|
||||
|
||||
const __m256i m4 = _mm256_set1_epi8(0xF);
|
||||
const __m256i m2 = _mm256_set1_epi8(3);
|
||||
const __m256i m32s = _mm256_set1_epi8(32);
|
||||
const __m256i m3 = _mm256_set1_epi8(3);
|
||||
const __m256i m15 = _mm256_set1_epi8(15);
|
||||
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
@@ -2314,53 +2313,45 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
const uint8_t * GGML_RESTRICT qh = x[i].qh;
|
||||
const int8_t * GGML_RESTRICT q8 = y[i].qs;
|
||||
|
||||
const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
|
||||
const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
|
||||
const __m256i scales_16 = _mm256_cvtepi8_epi16(scales);
|
||||
const __m256i q8sclsub = _mm256_slli_epi32(_mm256_madd_epi16(q8sums, scales_16), 5);
|
||||
|
||||
__m256i sumi = _mm256_setzero_si256();
|
||||
|
||||
int is = 0;
|
||||
|
||||
for (int j = 0; j < QK_K/128; ++j) {
|
||||
|
||||
const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
|
||||
const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
|
||||
const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
|
||||
const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
|
||||
is += 4;
|
||||
|
||||
const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
|
||||
const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
|
||||
const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
|
||||
|
||||
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
|
||||
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
|
||||
const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
|
||||
const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
|
||||
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m3), 4);
|
||||
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, _mm256_set1_epi8(12)), 2);
|
||||
const __m256i q4h_2 = _mm256_and_si256(q4bitsH, _mm256_set1_epi8(48));
|
||||
const __m256i q4h_3 = _mm256_srli_epi16(_mm256_and_si256(q4bitsH, _mm256_set1_epi8(-64)), 2);
|
||||
|
||||
const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
|
||||
const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
|
||||
const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
|
||||
const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
|
||||
const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m15), q4h_0);
|
||||
const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m15), q4h_1);
|
||||
const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m15), q4h_2);
|
||||
const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m15), q4h_3);
|
||||
|
||||
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
|
||||
__m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
|
||||
__m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
|
||||
__m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
|
||||
__m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
|
||||
|
||||
__m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
|
||||
__m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
|
||||
__m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
|
||||
__m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
|
||||
|
||||
p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
|
||||
p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
|
||||
p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
|
||||
p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
|
||||
const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
|
||||
const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
|
||||
const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
|
||||
const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
|
||||
is += 4;
|
||||
|
||||
p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
|
||||
p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
|
||||
@@ -2372,6 +2363,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
|
||||
}
|
||||
|
||||
sumi = _mm256_sub_epi32(sumi, q8sclsub);
|
||||
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
|
||||
}
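The AVX2 rewrite above folds the constant -32 offset of the 6-bit quants out of the inner loop: instead of subtracting 32 from every product (the old m32s/maddubs path), it subtracts scales times block sums times 32 once per block (the new q8sclsub term, a madd shifted left by 5). Assuming q6 denotes the reassembled 6-bit quants, q8 the int8 activations, s_i the per-group scales and bsum_i = sum_j q8_{ij}, this is just the identity

$$\sum_i s_i \sum_j (q6_{ij} - 32)\, q8_{ij} \;=\; \sum_i s_i \sum_j q6_{ij}\, q8_{ij} \;-\; 32 \sum_i s_i\, \mathrm{bsum}_i .$$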
|
||||
|
||||
|
||||
@@ -1036,12 +1036,12 @@ inline static float ggml_gelu_quick_f32(float x) {
return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
}
//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
// const uint16_t * i16 = (const uint16_t *) x;
// for (int i = 0; i < n; ++i) {
// y[i] = ggml_table_gelu_quick_f16[i16[i]];
// }
//}
inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
const uint16_t * i16 = (const uint16_t *) x;
for (int i = 0; i < n; ++i) {
y[i] = ggml_table_gelu_quick_f16[i16[i]];
}
}
#ifdef GGML_GELU_QUICK_FP16
inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
@@ -1060,13 +1060,6 @@ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float *
}
#endif
inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
for (int i = 0; i < n; ++i) {
float v = GGML_CPU_FP16_TO_FP32(x[i]);
y[i] = GGML_CPU_FP32_TO_FP16(v*(1.0f/(1.0f+expf(GELU_QUICK_COEF*v))));
}
}
// Sigmoid Linear Unit (SiLU) function
inline static float ggml_silu_f32(float x) {
return x/(1.0f + expf(-x));
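For reference, the quantity computed by ggml_gelu_quick_f32 above is the usual "quick GELU" approximation; assuming GELU_QUICK_COEF is the conventional -1.702, it evaluates

$$\operatorname{gelu\_quick}(x) \;=\; x \cdot \sigma(1.702\,x) \;=\; \frac{x}{1 + e^{-1.702\,x}},$$

and the re-enabled ggml_vec_gelu_quick_f16 variant simply looks this value up in the precomputed table ggml_table_gelu_quick_f16, indexed by the raw fp16 bit pattern of x.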
@@ -3478,10 +3478,10 @@ template <ggml_type type, int mmq_x, bool need_check>
|
||||
static __global__ void mul_mat_q(
|
||||
const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
|
||||
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
|
||||
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const int ncols_max) {
|
||||
const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
|
||||
const uint3 channel_ratio, const uint3 nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const uint3 sample_ratio, const uint3 nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const uint3 ntx) {
|
||||
|
||||
// Skip unused template specializations for faster compilation:
|
||||
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
|
||||
@@ -3495,8 +3495,7 @@ static __global__ void mul_mat_q(
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int mmq_y = get_mmq_y_device();
|
||||
|
||||
const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
|
||||
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
|
||||
const uint32_t nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
|
||||
|
||||
// Initialize the ids for writing back data with just the index.
|
||||
// For regular matrix multiplications this is never changed.
|
||||
@@ -3517,8 +3516,9 @@ static __global__ void mul_mat_q(
|
||||
// On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
|
||||
#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
||||
{
|
||||
const int wt = blockIdx.z / nchannels_y;
|
||||
const int zt = blockIdx.z - wt*nchannels_y;
|
||||
const uint2 tmp2 = fast_div_modulo(blockIdx.z, nchannels_y);
|
||||
const int wt = tmp2.x;
|
||||
const int zt = tmp2.y;
|
||||
const int jt = blockIdx.y;
|
||||
const int it = blockIdx.x;
|
||||
|
||||
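Several hunks in this kernel replace plain integer / and % by fastdiv/fastmodulo with divisor parameters passed as uint3. A minimal sketch of the underlying magic-number trick follows; the exact packing of the uint3 (divisor, multiplier, shift) and the helper names in the real CUDA headers are assumptions, not the project's exact code.

```cpp
#include <cstdint>

// Precompute a magic multiplier and shift for a fixed divisor d (host side),
// so the device can divide with one 32x32->64 multiply plus a shift.
struct fastdiv_vals { uint32_t mp; uint32_t l; uint32_t d; };

static fastdiv_vals init_fastdiv_values(uint32_t d) {
    uint32_t l = 0;
    while ((uint64_t(1) << l) < d) { ++l; }            // l = ceil(log2(d))
    const uint64_t mp = ((uint64_t(1) << 32) * ((uint64_t(1) << l) - d)) / d + 1;
    return { uint32_t(mp), l, d };
}

static uint32_t fastdiv(uint32_t n, fastdiv_vals v) {
    const uint64_t hi = (uint64_t(n) * v.mp) >> 32;    // high 32 bits of n * mp
    return uint32_t((hi + n) >> v.l);
}

static uint32_t fastmodulo(uint32_t n, fastdiv_vals v) {
    return n - fastdiv(n, v) * v.d;
}
```

With helpers of this shape, an expression like kbc % blocks_per_ne00 in the old code becomes fastmodulo(kbc, blocks_per_ne00), trading a runtime integer division inside the kernel for a multiply and a shift against constants computed once on the host.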
@@ -3561,40 +3561,40 @@ static __global__ void mul_mat_q(
|
||||
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
||||
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
||||
|
||||
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
||||
const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
||||
|
||||
constexpr bool fixup = false;
|
||||
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
|
||||
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
|
||||
tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
|
||||
tile_x_max_i, tile_y_max_j, 0, blocks_per_ne00.z);
|
||||
return;
|
||||
}
|
||||
#endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
||||
|
||||
constexpr int ITER_K = get_iter_k(type);
|
||||
|
||||
const int64_t blocks_per_ne00 = ncols_x / qk;
|
||||
constexpr int blocks_per_iter = ITER_K / qk;
|
||||
constexpr int ITER_K = get_iter_k(type);
|
||||
constexpr int blocks_per_iter = ITER_K / qk;
|
||||
|
||||
// kbc == k block continuous, current index in continuous ijk space.
|
||||
int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
||||
int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
||||
int kbc = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
||||
int kbc_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
||||
|
||||
kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
|
||||
kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
|
||||
kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter;
|
||||
kbc_stop -= fastmodulo(kbc_stop, blocks_per_ne00) % blocks_per_iter;
|
||||
|
||||
// kb0 == k index when doing the matrix multiplication for an output tile.
|
||||
int kb0_start = kbc % blocks_per_ne00;
|
||||
int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
|
||||
while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
|
||||
int tmp = kbc;
|
||||
const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
||||
tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
||||
const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
|
||||
tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
|
||||
const int zt = tmp / (ntx*blocks_per_ne00);
|
||||
tmp -= zt * (ntx*blocks_per_ne00);
|
||||
const int jt = tmp / blocks_per_ne00;
|
||||
int kb0_start = fastmodulo(kbc, blocks_per_ne00);
|
||||
int kb0_stop = min(blocks_per_ne00.z, uint32_t(kb0_start + kbc_stop - kbc));
|
||||
while (kbc < kbc_stop && kb0_stop == int(blocks_per_ne00.z)) {
|
||||
int tmp = fastdiv(kbc, blocks_per_ne00);
|
||||
uint2 tmp2 = fast_div_modulo(tmp, ntx);
|
||||
const int jt = tmp2.y;
|
||||
tmp = tmp2.x;
|
||||
tmp2 = fast_div_modulo(tmp, nchannels_y);
|
||||
const int zt = tmp2.y;
|
||||
tmp = tmp2.x;
|
||||
tmp2 = fast_div_modulo(tmp, nsamples_y);
|
||||
const int wt = tmp2.y;
|
||||
const int it = tmp2.x;
|
||||
|
||||
// Defaults for regular matrix multiplication:
|
||||
int col_low = 0;
|
||||
@@ -3612,11 +3612,11 @@ static __global__ void mul_mat_q(
|
||||
offset_dst = 0;
|
||||
|
||||
if (jt*mmq_x >= col_diff) {
|
||||
kbc += blocks_per_ne00;
|
||||
kbc -= kbc % blocks_per_ne00;
|
||||
kbc += blocks_per_ne00.z;
|
||||
kbc -= fastmodulo(kbc, blocks_per_ne00);
|
||||
|
||||
kb0_start = 0;
|
||||
kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
|
||||
kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc));
|
||||
|
||||
continue;
|
||||
}
|
||||
@@ -3641,32 +3641,34 @@ static __global__ void mul_mat_q(
|
||||
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
||||
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
||||
|
||||
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
||||
const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
||||
|
||||
constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
||||
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
|
||||
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
|
||||
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
|
||||
|
||||
kbc += blocks_per_ne00;
|
||||
kbc -= kbc % blocks_per_ne00;
|
||||
kbc += blocks_per_ne00.z;
|
||||
kbc -= fastmodulo(kbc, blocks_per_ne00);
|
||||
|
||||
kb0_start = 0;
|
||||
kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
|
||||
kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc));
|
||||
}
|
||||
|
||||
if (kbc >= kbc_stop) {
|
||||
return;
|
||||
}
|
||||
|
||||
int tmp = kbc;
|
||||
const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
||||
tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
||||
const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
|
||||
tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
|
||||
const int zt = tmp / (ntx*blocks_per_ne00);
|
||||
tmp -= zt * (ntx*blocks_per_ne00);
|
||||
const int jt = tmp / blocks_per_ne00;
|
||||
int tmp = fastdiv(kbc, blocks_per_ne00);
|
||||
uint2 tmp2 = fast_div_modulo(tmp, ntx);
|
||||
const int jt = tmp2.y;
|
||||
tmp = tmp2.x;
|
||||
tmp2 = fast_div_modulo(tmp, nchannels_y);
|
||||
const int zt = tmp2.y;
|
||||
tmp = tmp2.x;
|
||||
tmp2 = fast_div_modulo(tmp, nsamples_y);
|
||||
const int wt = tmp2.y;
|
||||
const int it = tmp2.x;
|
||||
|
||||
// Defaults for regular matrix multiplication:
|
||||
int col_low = 0;
|
||||
@@ -3708,7 +3710,7 @@ static __global__ void mul_mat_q(
|
||||
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
||||
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
||||
|
||||
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
||||
const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
||||
|
||||
constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
|
||||
@@ -3717,46 +3719,37 @@ static __global__ void mul_mat_q(
|
||||
}
|
||||
|
||||
template <ggml_type type, int mmq_x, bool need_check>
|
||||
static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
||||
const int32_t * expert_bounds,
|
||||
float * __restrict__ dst,
|
||||
const float * __restrict__ tmp_last_tile,
|
||||
const int ncols_x,
|
||||
const int nrows_x,
|
||||
const int ncols_dst,
|
||||
const size_t stride_col_dst,
|
||||
const int nchannels_y,
|
||||
const size_t stride_channel_dst,
|
||||
const int nsamples_y,
|
||||
const size_t stride_sample_dst,
|
||||
const int ncols_max) {
|
||||
constexpr int mmq_y = get_mmq_y_device();
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int ITER_K = get_iter_k(type);
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device()/2, 1)
|
||||
static __global__ void mul_mat_q_stream_k_fixup(
|
||||
const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
|
||||
float * __restrict__ tmp_last_tile, const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst,
|
||||
const int stride_col_dst, const uint3 nchannels_y, const int stride_channel_dst, const uint3 nsamples_y,
|
||||
const int stride_sample_dst, const uint3 ntx) {
|
||||
constexpr int mmq_y = get_mmq_y_device();
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int ITER_K = get_iter_k(type);
|
||||
constexpr int blocks_per_iter = ITER_K / qk;
|
||||
|
||||
constexpr int blocks_per_iter = ITER_K / qk;
|
||||
const int64_t blocks_per_ne00 = ncols_x / qk;
|
||||
|
||||
constexpr int nwarps = mmq_get_nwarps_device();
|
||||
constexpr int nwarps = mmq_get_nwarps_device()/2;
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
|
||||
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
|
||||
float sum[mmq_x / nwarps] = {0.0f};
|
||||
const int i = blockIdx.y*warp_size + threadIdx.x;
|
||||
|
||||
const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
|
||||
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
|
||||
// kbc == k block continuous, current index in continuous ijk space.
|
||||
int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
||||
int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
||||
int kbc0 = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
||||
int kbc0_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
||||
|
||||
kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter;
|
||||
kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;
|
||||
kbc0 -= fastmodulo(kbc0, blocks_per_ne00) % blocks_per_iter;
|
||||
kbc0_stop -= fastmodulo(kbc0_stop, blocks_per_ne00) % blocks_per_iter;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
|
||||
const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
|
||||
const bool wrote_beginning_of_tile = fastmodulo(kbc0, blocks_per_ne00) == 0;
|
||||
const bool did_not_write_last = fastdiv(kbc0, blocks_per_ne00) == fastdiv(kbc0_stop, blocks_per_ne00) && fastmodulo(kbc0_stop, blocks_per_ne00) != 0;
|
||||
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
|
||||
return;
|
||||
}
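The early-out conditions above decide whether this CUDA block has any fixing up to do at all: stream-k hands every block an equal slice of the flattened k-block range, so a tile whose k-range straddles two slices is started by one block and topped up by this fixup kernel. The following host-side C++ model of that slicing uses plain `/` and `%` instead of the fastdiv helpers and made-up sizes; it only illustrates when a slice starts or ends mid-tile, which is exactly the case the fixup pass handles.

```cpp
#include <cstdint>
#include <cstdio>

// Host-side model of the stream-k work split used by mul_mat_q (illustrative sizes).
int main() {
    const int ntiles          = 7;   // output tiles (ntx*nty*nchannels*nsamples)
    const int blocks_per_ne00 = 12;  // k-blocks per tile
    const int blocks_per_iter = 4;   // k-blocks consumed per kernel iteration
    const int grid            = 5;   // number of CUDA blocks (roughly the SM count)

    const int64_t total = int64_t(ntiles) * blocks_per_ne00;

    for (int bidx = 0; bidx < grid; bidx++) {
        int kbc      = int( int64_t(bidx)      * total / grid);
        int kbc_stop = int((int64_t(bidx) + 1) * total / grid);
        // snap both ends to an iteration boundary inside their tile, mirroring:
        // kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
        kbc      -= (kbc      % blocks_per_ne00) % blocks_per_iter;
        kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;

        // A block whose slice ends mid-tile writes its partial tile to the fixup
        // buffer; a block whose slice starts mid-tile later sums those partials in.
        const bool ends_mid_tile   = kbc != kbc_stop && kbc_stop % blocks_per_ne00 != 0;
        const bool starts_mid_tile = kbc % blocks_per_ne00 != 0;
        printf("block %d: kbc=[%d,%d) starts_mid_tile=%d ends_mid_tile=%d\n",
               bidx, kbc, kbc_stop, starts_mid_tile, ends_mid_tile);
    }
    return 0;
}
```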
|
||||
@@ -3765,11 +3758,11 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
||||
|
||||
// Iterate over previous blocks and sum up partial sums written to fixup buffer.
|
||||
// All CUDA blocks that get here must have a previous block that needs a fixup.
|
||||
int64_t bidx = bidx0 - 1;
|
||||
int64_t kbc_stop = kbc0;
|
||||
int bidx = bidx0 - 1;
|
||||
int kbc_stop = kbc0;
|
||||
while(true) {
|
||||
int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
||||
kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
|
||||
int kbc = int64_t(bidx)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
||||
kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter;
|
||||
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
@@ -3779,20 +3772,16 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
||||
|
||||
any_fixup = true;
|
||||
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
|
||||
}
|
||||
sum[j0/nwarps] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
|
||||
}
|
||||
|
||||
// If this block started in a previous tile we are done and don't need to combine additional partial results.
|
||||
if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {
|
||||
if (fastmodulo(kbc, blocks_per_ne00) == 0 || fastdiv(kbc, blocks_per_ne00) < fastdiv(kbc0, blocks_per_ne00)) {
|
||||
break;
|
||||
}
|
||||
bidx--;
|
||||
@@ -3803,14 +3792,16 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
||||
return;
|
||||
}
|
||||
|
||||
int tmp = kbc0;
|
||||
const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
||||
tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
||||
const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
|
||||
tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
|
||||
const int zt = tmp / (ntx*blocks_per_ne00);
|
||||
tmp -= zt * (ntx*blocks_per_ne00);
|
||||
const int jt = tmp / blocks_per_ne00;
|
||||
int tmp = fastdiv(kbc0, blocks_per_ne00);
|
||||
uint2 tmp2 = fast_div_modulo(tmp, ntx);
|
||||
const int jt = tmp2.y;
|
||||
tmp = tmp2.x;
|
||||
tmp2 = fast_div_modulo(tmp, nchannels_y);
|
||||
const int zt = tmp2.y;
|
||||
tmp = tmp2.x;
|
||||
tmp2 = fast_div_modulo(tmp, nsamples_y);
|
||||
const int wt = tmp2.y;
|
||||
const int it = tmp2.x;
|
||||
|
||||
if (!ids_dst) {
|
||||
const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
|
||||
@@ -3818,6 +3809,9 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
||||
|
||||
const int i_max = nrows_x - it*mmq_y - 1;
|
||||
const int j_max = ncols_dst - jt*mmq_x - 1;
|
||||
if (need_check && i > i_max) {
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||
@@ -3827,16 +3821,7 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
if (need_check && i > i_max) {
|
||||
continue;
|
||||
}
|
||||
|
||||
dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
|
||||
}
|
||||
dst[j*stride_col_dst + i] += sum[j0/nwarps];
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -3856,6 +3841,9 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
||||
|
||||
const int i_max = nrows_x - it*mmq_y - 1;
|
||||
const int j_max = col_diff - jt*mmq_x - 1;
|
||||
if (need_check && i > i_max) {
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||
@@ -3865,16 +3853,7 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
if (need_check && i > i_max) {
|
||||
continue;
|
||||
}
|
||||
|
||||
dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
|
||||
}
|
||||
dst[ids_dst_shared[j]*stride_col_dst + i] += sum[j0/nwarps];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3922,29 +3901,44 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
const int channel_ratio = args.nchannels_y / args.nchannels_x;
|
||||
const int sample_ratio = args.nsamples_y / args.nsamples_x;
|
||||
|
||||
const uint3 blocks_per_ne00_fd = init_fastdiv_values(args.ncols_x / ggml_cuda_type_traits<type>::qk);
|
||||
const uint3 ntx_fd = init_fastdiv_values(ntx);
|
||||
const uint3 nchannels_y_fd = init_fastdiv_values(args.nchannels_y);
|
||||
const uint3 nsamples_y_fd = init_fastdiv_values(args.nsamples_y);
|
||||
const uint3 channel_ratio_fd = init_fastdiv_values(channel_ratio);
|
||||
const uint3 sample_ratio_fd = init_fastdiv_values(sample_ratio);
|
||||
|
||||
if (!args.use_stream_k) {
|
||||
if (args.nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
ntx_fd);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
ntx_fd);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const dim3 block_nums_stream_k(nsm, 1, 1);
|
||||
const bool fixup_needed = ntx*nty*ntzw % nsm != 0;
|
||||
// For the stream-k kernel it is possible to run it with tiling by setting the number of CUDA blocks equal to the number of tiles.
|
||||
// This is worthwhile if the efficiency of tiling is high and skipping the fixup kernel is more important.
|
||||
const int ntiles_dst = ntx * nty * ntzw;
|
||||
const int tiles_nwaves = (ntiles_dst + nsm - 1) / nsm;
|
||||
const int tiles_efficiency_percent = 100 * ntiles_dst / (nsm*tiles_nwaves);
|
||||
const dim3 block_nums_stream_k(GGML_CUDA_CC_IS_NVIDIA(cc) && tiles_efficiency_percent >= 90 ? ntiles_dst : nsm, 1, 1);
|
||||
|
||||
GGML_ASSERT(ntiles_dst * blocks_per_ne00_fd.z < (1 << 30)); // Assert that variable kbc will not overflow.
|
||||
|
||||
const bool fixup_needed = ntiles_dst % block_nums_stream_k.x != 0;
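The new launch logic above only falls back to a plain one-block-per-tile grid (which skips the fixup kernel entirely) when that wastes little hardware: it counts how many waves the tiles need and what fraction of the grid the waves keep busy, and on NVIDIA GPUs switches over at 90%. A quick worked example of that threshold, with an assumed SM count and example tile counts:

```cpp
#include <cstdio>

int main() {
    const int nsm = 84; // number of SMs (example value)
    for (int ntiles_dst : { 100, 160, 168 }) {
        const int nwaves  = (ntiles_dst + nsm - 1) / nsm;
        const int percent = 100 * ntiles_dst / (nsm * nwaves);
        // >= 90% utilisation across the waves -> one CUDA block per tile,
        // no fixup pass; otherwise keep the stream-k split over nsm blocks.
        printf("ntiles=%d waves=%d efficiency=%d%% -> %s\n",
               ntiles_dst, nwaves, percent,
               percent >= 90 ? "per-tile grid" : "stream-k grid (nsm blocks)");
    }
    return 0;
}
// -> 100 tiles: 59% (stream-k), 160 tiles: 95% (per-tile), 168 tiles: 100% (per-tile)
```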
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool(id);
|
||||
ggml_cuda_pool_alloc<float> tmp_fixup(pool);
|
||||
@@ -3952,40 +3946,45 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
|
||||
}
|
||||
|
||||
const dim3 block_nums_fixup(block_nums_stream_k.x, mmq_y/warp_size, 1);
|
||||
const dim3 block_dims_fixup(block_dims.x, block_dims.y/2, block_dims.z);
|
||||
|
||||
if (args.nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
ntx_fd);
|
||||
|
||||
if (!fixup_needed) {
|
||||
return;
|
||||
}
|
||||
|
||||
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
|
||||
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_fixup, block_dims_fixup, 0, stream>>>
|
||||
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst,
|
||||
args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst,
|
||||
ntx_fd);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
ntx_fd);
|
||||
|
||||
if (!fixup_needed) {
|
||||
return;
|
||||
}
|
||||
|
||||
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
|
||||
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_fixup, block_dims_fixup, 0, stream>>>
|
||||
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst,
|
||||
args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst,
|
||||
ntx_fd);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1683,7 +1683,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
|
||||
__fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
|
||||
assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget);
|
||||
|
||||
FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", __func__, m, k, n, weight_type,
|
||||
FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type,
|
||||
M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget);
|
||||
|
||||
// initialize eye tile (32x32 identity matrix)
|
||||
|
||||
@@ -101,6 +101,26 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
|
||||
}
|
||||
}
|
||||
|
||||
#if __HVX_ARCH__ >= 75
{
// Set HMX clock
HAP_power_request_t request;
memset(&request, 0, sizeof(HAP_power_request_t));
request.type = HAP_power_set_HMX_v2;
request.hmx_v2.set_clock = TRUE;
request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX;
request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX;
request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX;
request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH;
FARF(ALWAYS, "Setting HMX clock\n");
err = HAP_power_set((void *) &ctx, &request);
if (err != AEE_SUCCESS) {
FARF(ERROR, "Error setting HMX clock.");
return err;
}
}
#endif

return AEE_SUCCESS;
}
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ CatalogFile = libggml-htp.cat
|
||||
PnpLockDown = 1
|
||||
|
||||
[DestinationDirs]
|
||||
Drivers_Dir = 6
|
||||
Drivers_Dir = 13
|
||||
|
||||
[SourceDisksNames]
|
||||
1 = %DiskId%
|
||||
|
||||
@@ -677,7 +677,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta
|
||||
const ggml_type tsrc1 = op->src[1]->type;
|
||||
|
||||
const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
|
||||
const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;
|
||||
|
||||
constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
|
||||
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
|
||||
|
||||
const bool has_tensor = ggml_metal_device_get_props(ggml_metal_library_get_device(lib))->has_tensor;
|
||||
|
||||
const bool bc_out = has_tensor
|
||||
? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0)
|
||||
: (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0);
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
|
||||
snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
|
||||
@@ -694,8 +702,20 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
// when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
|
||||
res.smem = bc_out ? 8192 : 4096 + 2048;
|
||||
if (has_tensor) {
|
||||
res.nr0 = NRA;
|
||||
res.nr1 = NRB;
|
||||
|
||||
const size_t smem_a = NRA * N_MM_NK_TOTAL * sizeof(ggml_fp16_t);
|
||||
res.smem = smem_a;
|
||||
} else {
|
||||
res.nr0 = 64;
|
||||
res.nr1 = 32;
|
||||
|
||||
res.smem = bc_out ? 8192 : (4096 + 2048);
|
||||
}
|
||||
|
||||
res.nsg = N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -102,6 +102,8 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
|
||||
|
||||
void ggml_metal_library_free(ggml_metal_library_t lib);
|
||||
|
||||
ggml_metal_device_t ggml_metal_library_get_device(ggml_metal_library_t lib);
|
||||
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
|
||||
|
||||
|
||||
@@ -95,8 +95,8 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_wi
|
||||
|
||||
struct ggml_metal_library {
|
||||
id<MTLLibrary> obj;
|
||||
id<MTLDevice> device;
|
||||
|
||||
ggml_metal_device_t dev;
|
||||
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
|
||||
|
||||
NSLock * lock;
|
||||
@@ -251,7 +251,7 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
||||
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
|
||||
|
||||
res->obj = library;
|
||||
res->device = device;
|
||||
res->dev = dev;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
res->lock = [NSLock new];
|
||||
|
||||
@@ -318,7 +318,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
|
||||
}
|
||||
|
||||
res->obj = library;
|
||||
res->device = device;
|
||||
res->dev = dev;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
res->lock = [NSLock new];
|
||||
|
||||
@@ -341,6 +341,10 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
|
||||
free(lib);
|
||||
}
|
||||
|
||||
ggml_metal_device_t ggml_metal_library_get_device(ggml_metal_library_t lib) {
|
||||
return lib->dev;
|
||||
}
|
||||
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
|
||||
[lib->lock lock];
|
||||
|
||||
@@ -405,7 +409,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
|
||||
return res;
|
||||
}
|
||||
|
||||
id<MTLComputePipelineState> obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
|
||||
id<MTLDevice> device = ggml_metal_device_get_obj(lib->dev);
|
||||
id<MTLComputePipelineState> obj = [device newComputePipelineStateWithFunction:mtl_function error:&error];
|
||||
|
||||
[mtl_function release];
|
||||
|
||||
@@ -699,7 +704,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
|
||||
" auto sB = tB.slice(0, 0); \n"
|
||||
" mm.run(sB, sA, cT); \n"
|
||||
" \n"
|
||||
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
|
||||
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(16, 16)); \n"
|
||||
" \n"
|
||||
" cT.store(tC); \n"
|
||||
"}";
|
||||
@@ -749,7 +754,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
|
||||
" auto sB = tB.slice(0, 0); \n"
|
||||
" mm.run(sB, sA, cT); \n"
|
||||
" \n"
|
||||
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
|
||||
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(16, 16)); \n"
|
||||
" \n"
|
||||
" cT.store(tC); \n"
|
||||
"}";
|
||||
@@ -814,7 +819,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
|
||||
}
|
||||
|
||||
// print MTL GPU family:
|
||||
GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name);
|
||||
GGML_LOG_INFO("%s: GPU name: %s (%s)\n", __func__, dev->props.name, dev->props.desc);
|
||||
|
||||
// determine max supported GPU family
|
||||
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||
|
||||
@@ -1,6 +1,19 @@
#ifndef GGML_METAL_IMPL
#define GGML_METAL_IMPL

// kernel parameters for mat-mat threadgroups
//
// TODO: become function constants

#define SZ_SIMDGROUP 16
#define N_MM_NK 2
#define N_MM_NK_TOTAL (SZ_SIMDGROUP * N_MM_NK)

#define N_MM_BLOCK_X 4
#define N_MM_BLOCK_Y 2
#define N_MM_SIMD_GROUP_X 2
#define N_MM_SIMD_GROUP_Y 2

// kernel parameters for mat-vec threadgroups
//
// N_R0: number of src0 rows to process per simdgroup
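Plugging these defines into the expressions used by the pipeline selection earlier in the patch (NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y, NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X, smem_a = NRA * N_MM_NK_TOTAL * sizeof(ggml_fp16_t)) gives the tile geometry used on the tensor-ops path. The snippet below is only a quick host-side check of that arithmetic, not part of the patch:

```cpp
#include <cstdio>
#include <cstdint>

int main() {
    const int SZ_SIMDGROUP      = 16;
    const int N_MM_NK           = 2;
    const int N_MM_NK_TOTAL     = SZ_SIMDGROUP * N_MM_NK;   // 32 k-elements per iteration
    const int N_MM_BLOCK_X      = 4, N_MM_BLOCK_Y      = 2;
    const int N_MM_SIMD_GROUP_X = 2, N_MM_SIMD_GROUP_Y = 2;

    const int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y; // rows of A per threadgroup
    const int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; // rows of B per threadgroup
    const size_t smem_a = size_t(NRA) * N_MM_NK_TOTAL * sizeof(uint16_t); // fp16 A staging tile

    printf("NRA=%d NRB=%d NK=%d smem_a=%zu bytes\n", NRA, NRB, N_MM_NK_TOTAL, smem_a);
    // -> NRA=64 NRB=128 NK=32 smem_a=4096 bytes
    return 0;
}
```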
@@ -2195,7 +2195,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
|
||||
|
||||
const int nr0 = pipeline.nr0;
|
||||
const int nr1 = pipeline.nr1;
|
||||
const int nsg = pipeline.nsg;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + nr1 - 1) / nr1), ((ne01 + nr0 - 1) / nr0), ne12 * ne13, 32, nsg, 1);
|
||||
} else {
|
||||
auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
|
||||
|
||||
|
||||
@@ -9306,7 +9306,137 @@ constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
|
||||
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
|
||||
|
||||
// each block_q contains 16*nl weights
|
||||
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
template<
|
||||
typename SA, typename SA_4x4, typename SA_8x8,
|
||||
typename SB, typename SB_2x4, typename SB_8x8,
|
||||
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &),
|
||||
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm(
|
||||
constant ggml_metal_kargs_mul_mm & args,
|
||||
device const char * srcA,
|
||||
device const char * srcB,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig [[threadgroup_position_in_grid]],
|
||||
ushort tiitg [[thread_index_in_threadgroup]],
|
||||
ushort sgitg [[simdgroup_index_in_threadgroup]]) {
|
||||
(void) sgitg;
|
||||
|
||||
// Matrix dimensions: A(M,K) x B(K,N) -> C(M,N)
|
||||
const int K = args.ne00;
|
||||
const int M = args.ne0;
|
||||
const int N = args.ne1;
|
||||
|
||||
// Batch dimension handling
|
||||
const int im = tgpig.z;
|
||||
const int i12 = im % args.ne12;
|
||||
const int i13 = im / args.ne12;
|
||||
|
||||
// Batch offsets for srcA and srcB
|
||||
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
|
||||
// Tile dimensions
|
||||
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
|
||||
constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
|
||||
|
||||
// Tile offsets in output matrix
|
||||
const int ra = tgpig.y * NRA;
|
||||
const int rb = tgpig.x * NRB;
|
||||
|
||||
// Threadgroup memory for dequantized A tile only
|
||||
threadgroup SA * sa = (threadgroup SA *)(shmem);
|
||||
|
||||
// Work-item count for A loading
|
||||
constexpr int A_WORK_ITEMS = NRA * N_MM_NK;
|
||||
constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y;
|
||||
|
||||
// tA wraps threadgroup memory
|
||||
auto tA = tensor(sa, dextents<int32_t, 2>(N_MM_NK_TOTAL, NRA));
|
||||
|
||||
// tB wraps device memory directly
|
||||
device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13);
|
||||
const int strideB = args.nb11 / sizeof(T1);
|
||||
auto tB = tensor(ptrB, dextents<int32_t, 2>(K, N), array<int, 2>({1, strideB}));
|
||||
|
||||
// Configure matmul operation
|
||||
mpp::tensor_ops::matmul2d<
|
||||
mpp::tensor_ops::matmul2d_descriptor(
|
||||
NRB, NRA, N_MM_NK_TOTAL, false, true, true,
|
||||
mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
|
||||
execution_simdgroups<N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y>> mm;
|
||||
|
||||
auto cT = mm.get_destination_cooperative_tensor<decltype(tB), decltype(tA), float>();
|
||||
|
||||
// Accumulate partial results over K dimension
|
||||
for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) {
|
||||
// === PHASE 1: Dequantization of A into threadgroup memory ===
|
||||
for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) {
|
||||
const int row = work / N_MM_NK;
|
||||
const int k_chunk = work % N_MM_NK;
|
||||
const int k_pos = loop_k + k_chunk * 16;
|
||||
const short k_base = k_chunk * 16;
|
||||
|
||||
// Bounds check: skip device read if row is out of matrix bounds
|
||||
if (ra + row < M) {
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
// Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4).
|
||||
// MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd,
|
||||
// nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned.
|
||||
// Mirrors the legacy kernel's existing guard.
|
||||
device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0;
|
||||
}
|
||||
} else {
|
||||
const int block_idx = k_pos / (16 * nl);
|
||||
const short il = (k_pos / 16) % nl;
|
||||
|
||||
device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0);
|
||||
|
||||
SA_4x4 temp_a;
|
||||
dequantize_func(row_ptr + block_idx, il, temp_a);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
// Zero-pad A for K positions beyond valid range (handles partial K iterations)
|
||||
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Zero-pad rows beyond matrix bounds
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// === PHASE 2: Tensor matmul ===
|
||||
auto mA = tA.slice(0, 0);
|
||||
auto mB = tB.slice(loop_k, rb);
|
||||
|
||||
mm.run(mB, mA, cT);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
// Store result tile to output matrix (with batch offset)
|
||||
// cT.store handles bounds checking via tD's extents (M, N)
|
||||
device float * dstBatch = (device float *)dst + im * N * M;
|
||||
|
||||
auto tD = tensor(dstBatch, dextents<int32_t, 2>(M, N), array<int, 2>({1, M}));
|
||||
cT.store(tD.slice(ra, rb));
|
||||
}
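The bounds-check comment in the dequantization phase above cites the MSL alignment rule: half4x4 loads need 8-byte alignment, and with a row stride of nb01 = K * sizeof(half) an odd K leaves every other row pointer misaligned, hence the element-wise fallback through FC_mul_mm_bc_inp. A small C++ illustration of that stride arithmetic, assuming an F16 src0 (other types go through the same check via nb01):

```cpp
#include <cstdio>
#include <cstdint>

int main() {
    for (int K : { 4096, 4097 }) {
        const uint64_t nb01 = uint64_t(K) * 2; // row stride of an F16 matrix in bytes
        // half4x4 (and float4x4) vector loads require 8-byte alignment in MSL;
        // if the stride is not a multiple of 8, odd rows end up misaligned.
        printf("K=%d nb01=%llu -> %s\n", K, (unsigned long long) nb01,
               nb01 % 8 == 0 ? "vector loads OK" : "element-wise fallback");
    }
    return 0;
}
```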
|
||||
|
||||
#else
|
||||
|
||||
template<
|
||||
typename S0, typename S0_4x4, typename S0_8x8,
|
||||
typename S1, typename S1_2x4, typename S1_8x8,
|
||||
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &),
|
||||
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm(
|
||||
constant ggml_metal_kargs_mul_mm & args,
|
||||
device const char * src0,
|
||||
@@ -9320,10 +9450,6 @@ kernel void kernel_mul_mm(
|
||||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
threadgroup float * sc = (threadgroup float *)(shmem);
|
||||
#endif
|
||||
|
||||
constexpr int NR0 = 64;
|
||||
constexpr int NR1 = 32;
|
||||
|
||||
@@ -9363,7 +9489,6 @@ kernel void kernel_mul_mm(
|
||||
+ args.nb11*(r1 + lr1)
|
||||
+ args.nb10*iy);
|
||||
|
||||
#ifndef GGML_METAL_HAS_TENSOR
|
||||
S0_8x8 ma[4];
|
||||
S1_8x8 mb[2];
|
||||
|
||||
@@ -9372,19 +9497,8 @@ kernel void kernel_mul_mm(
|
||||
for (short i = 0; i < 8; i++){
|
||||
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
||||
}
|
||||
#else
|
||||
auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
|
||||
auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
|
||||
|
||||
mpp::tensor_ops::matmul2d<
|
||||
mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
|
||||
execution_simdgroups<4>> mm;
|
||||
|
||||
auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
|
||||
#endif
|
||||
|
||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
|
||||
#ifndef GGML_METAL_HAS_TENSOR
|
||||
// load data and store to threadgroup memory
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@@ -9454,66 +9568,6 @@ kernel void kernel_mul_mm(
|
||||
|
||||
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||
}
|
||||
#else
|
||||
// load data and store to threadgroup memory
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// no need for dequantization
|
||||
for (short i = 0; i < 16; i++) {
|
||||
const short sx = 2*il0 + i/8;
|
||||
const short sy = (tiitg/NL0)/8;
|
||||
|
||||
const short lx = i%8;
|
||||
const short ly = (tiitg/NL0)%8;
|
||||
//const short lx = (tiitg/NL0)%8;
|
||||
//const short ly = i%8;
|
||||
|
||||
*(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
||||
}
|
||||
} else {
|
||||
S0_4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
const short sx = 2*il0 + i/8;
|
||||
const short sy = (tiitg/NL0)/8;
|
||||
|
||||
const short lx = i%8;
|
||||
const short ly = (tiitg/NL0)%8;
|
||||
//const short lx = (tiitg/NL0)%8;
|
||||
//const short ly = i%8;
|
||||
|
||||
*(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_mul_mm_bc_inp) {
|
||||
for (short i = 0; i < 8; ++i) {
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
const short lx = i;
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
//const short lx = (tiitg/NL1)%8;
|
||||
//const short ly = i;
|
||||
|
||||
*(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
|
||||
}
|
||||
} else {
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
//const short lx = i;
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
//const short lx = (tiitg/NL1)%8;
|
||||
//const short ly = i;
|
||||
|
||||
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||
}
|
||||
#endif
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
@@ -9522,7 +9576,6 @@ kernel void kernel_mul_mm(
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
#ifndef GGML_METAL_HAS_TENSOR
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
|
||||
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
|
||||
@@ -9549,24 +9602,10 @@ kernel void kernel_mul_mm(
|
||||
lsma += 8*64;
|
||||
lsmb += 4*64;
|
||||
}
|
||||
#else
|
||||
auto sA = tA.slice(0, 0);
|
||||
auto sB = tB.slice(0, 0);
|
||||
|
||||
mm.run(sB, sA, cT);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
|
||||
// if no bounds checks on the output are needed, we can directly write to device memory
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
device float * C = (device float *) dst +
|
||||
r0 + \
|
||||
r1 * args.ne0 + im*args.ne1*args.ne0;
|
||||
|
||||
auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
|
||||
cT.store(tC);
|
||||
#else
|
||||
device float * C = (device float *) dst +
|
||||
(r0 + 32*(sgitg & 1)) + \
|
||||
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
|
||||
@@ -9574,21 +9613,15 @@ kernel void kernel_mul_mm(
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
|
||||
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
|
||||
cT.store(tC);
|
||||
#else
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
|
||||
}
|
||||
#endif
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
@@ -9614,6 +9647,8 @@ kernel void kernel_mul_mm(
|
||||
}
|
||||
}
|
||||
|
||||
#endif // GGML_METAL_HAS_TENSOR
|
||||
|
||||
template<short ne20> // n_expert_used
|
||||
kernel void kernel_mul_mm_id_map0(
|
||||
constant ggml_metal_kargs_mul_mm_id_map0 & args,
|
||||
@@ -9789,7 +9824,7 @@ kernel void kernel_mul_mm_id(
|
||||
|
||||
const short ib = 8*sx + sy;
|
||||
|
||||
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
||||
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0;
|
||||
}
|
||||
} else {
|
||||
S0_4x4 temp_a;
|
||||
|
||||
@@ -96,6 +96,8 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mv_q6_k_f32_flat
|
||||
mul_mv_q8_0_f32
|
||||
mul_mv_q8_0_f32_flat
|
||||
mul_mv_iq4_nl_f32
|
||||
mul_mv_iq4_nl_f32_flat
|
||||
mul_mv_mxfp4_f32
|
||||
mul_mv_mxfp4_f32_flat
|
||||
mul_mv_id_q4_0_f32_8x_flat
|
||||
@@ -110,12 +112,15 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mm_q4_0_f32_l4_lm
|
||||
mul_mm_q4_1_f32_l4_lm
|
||||
mul_mm_q8_0_f32_l4_lm
|
||||
mul_mm_iq4_nl_f32_l4_lm
|
||||
mul_mm_q4_k_f32_l4_lm
|
||||
mul_mm_q5_k_f32_l4_lm
|
||||
mul_mm_q6_k_f32_l4_lm
|
||||
mul_mm_q8_0_f32_8x4
|
||||
gemv_noshuffle_q4_1_f32
|
||||
gemm_noshuffle_q4_1_f32
|
||||
gemv_noshuffle_iq4_nl_f32
|
||||
gemm_noshuffle_iq4_nl_f32
|
||||
gemv_noshuffle_general_q8_0_f32
|
||||
gemv_noshuffle_q4_k_f32
|
||||
gemm_noshuffle_q4_k_f32
|
||||
|
||||
@@ -545,6 +545,9 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_convert_block_q5_K_noshuffle;
|
||||
cl_kernel kernel_restore_block_q5_K_noshuffle;
|
||||
cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
|
||||
cl_kernel kernel_convert_block_iq4_nl, kernel_restore_block_iq4_nl;
|
||||
cl_kernel kernel_convert_block_iq4_nl_noshuffle;
|
||||
cl_kernel kernel_restore_block_iq4_nl_noshuffle;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
|
||||
cl_kernel kernel_mul_mv_q4_1_f32;
|
||||
cl_kernel kernel_mul_mv_q4_1_f32_flat;
|
||||
@@ -556,6 +559,8 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_mul_mv_q6_K_f32_flat;
|
||||
cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
|
||||
cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat;
|
||||
cl_kernel kernel_mul_mv_iq4_nl_f32;
|
||||
cl_kernel kernel_mul_mv_iq4_nl_f32_flat;
|
||||
cl_kernel kernel_solve_tri_f32;
|
||||
cl_kernel kernel_im2col_f32, kernel_im2col_f16;
|
||||
cl_kernel kernel_argsort_f32_i32;
|
||||
@@ -594,6 +599,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_mul_mm_q4_k_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_q5_k_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_q6_k_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_iq4_nl_f32_l4_lm;
|
||||
|
||||
std::vector<ProfilingInfo> profiling_info;
|
||||
|
||||
@@ -734,6 +740,8 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_gemm_noshuffle_q6_K_f32;
|
||||
cl_kernel kernel_gemv_noshuffle_q5_k_f32;
|
||||
cl_kernel kernel_gemm_noshuffle_q5_k_f32;
|
||||
cl_kernel kernel_gemv_noshuffle_iq4_nl_f32;
|
||||
cl_kernel kernel_gemm_noshuffle_iq4_nl_f32;
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
|
||||
void free() {
|
||||
@@ -954,6 +962,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K_noshuffle", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_iq4_nl = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_iq4_nl", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_iq4_nl_noshuffle", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl_noshuffle", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
@@ -1359,6 +1371,40 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_iq4_nl_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mv_iq4_nl_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mv_iq4_nl_f32.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mv_iq4_nl_f32 = clCreateKernel(prog, "kernel_mul_mv_iq4_nl_f32", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_iq4_nl_f32_flat
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mv_iq4_nl_f32_flat.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mv_iq4_nl_f32_flat.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mv_iq4_nl_f32_flat = clCreateKernel(prog, "kernel_mul_mv_iq4_nl_f32_flat", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_mxfp4_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -1567,6 +1613,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_iq4_nl_f32_l4_lm
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mm_iq4_nl_f32_l4_lm.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mm_iq4_nl_f32_l4_lm.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mm_iq4_nl_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_iq4_nl_f32_l4_lm", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_q4_k_f32_l4_lm
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -2647,6 +2710,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemm_noshuffle_iq4_nl_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "gemm_noshuffle_iq4_nl_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("gemm_noshuffle_iq4_nl_f32.cl");
|
||||
#endif
|
||||
cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
CL_CHECK((backend_ctx->kernel_gemm_noshuffle_iq4_nl_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_iq4_nl_f32", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemv_noshuffle_iq4_nl_f32
|
||||
{
|
||||
std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
|
||||
" -cl-mad-enable ";
|
||||
if (backend_ctx->has_vector_subgroup_broadcast) {
|
||||
CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST ";
|
||||
}
|
||||
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "gemv_noshuffle_iq4_nl_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("gemv_noshuffle_iq4_nl_f32.cl");
|
||||
#endif
|
||||
|
||||
cl_program prog = build_program_from_source(
|
||||
backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_gemv_noshuffle_iq4_nl_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_iq4_nl_f32", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_q8_0_f32_8x4
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -3597,6 +3699,30 @@ struct ggml_tensor_extra_cl_q8_0 {
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_tensor_extra_cl_iq4_nl {
|
||||
cl_mem q = nullptr;
|
||||
cl_mem q_img = nullptr;
|
||||
|
||||
cl_mem d = nullptr;
|
||||
cl_mem d_img = nullptr;
|
||||
|
||||
size_t size_q = 0;
|
||||
size_t size_d = 0;
|
||||
|
||||
~ggml_tensor_extra_cl_iq4_nl() {
|
||||
reset();
|
||||
}
|
||||
|
||||
void reset() {
|
||||
if (q != nullptr) { CL_CHECK(clReleaseMemObject(q)); q = nullptr; }
|
||||
if (d != nullptr) { CL_CHECK(clReleaseMemObject(d)); d = nullptr; }
|
||||
q_img = nullptr;
|
||||
d_img = nullptr;
|
||||
size_q = 0;
|
||||
size_d = 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_tensor_extra_cl_q4_K {
|
||||
// Quantized values
|
||||
cl_mem q = nullptr;
|
||||
@@ -4097,6 +4223,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
return op->src[1]->type == GGML_TYPE_F32;
|
||||
} else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 ||
|
||||
op->src[0]->type == GGML_TYPE_MXFP4 ||
|
||||
op->src[0]->type == GGML_TYPE_IQ4_NL ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_K ||
|
||||
op->src[0]->type == GGML_TYPE_Q5_K ||
|
||||
op->src[0]->type == GGML_TYPE_Q6_K) {
|
||||
@@ -4295,6 +4422,12 @@ struct ggml_backend_opencl_buffer_context {
|
||||
for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) {
|
||||
delete e;
|
||||
}
|
||||
for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl) {
|
||||
delete e;
|
||||
}
|
||||
for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl_in_use) {
|
||||
delete e;
|
||||
}
|
||||
for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K) {
|
||||
delete e;
|
||||
}
|
||||
@@ -4390,6 +4523,21 @@ struct ggml_backend_opencl_buffer_context {
|
||||
return extra;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_cl_iq4_nl * ggml_opencl_alloc_temp_tensor_extra_iq4_nl() {
|
||||
ggml_tensor_extra_cl_iq4_nl * extra;
|
||||
if (temp_tensor_extras_iq4_nl.empty()) {
|
||||
extra = new ggml_tensor_extra_cl_iq4_nl();
|
||||
} else {
|
||||
extra = temp_tensor_extras_iq4_nl.back();
|
||||
temp_tensor_extras_iq4_nl.pop_back();
|
||||
}
|
||||
|
||||
temp_tensor_extras_iq4_nl_in_use.push_back(extra);
|
||||
|
||||
extra->reset();
|
||||
return extra;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_cl_q4_K * ggml_opencl_alloc_temp_tensor_extra_q4_K() {
|
||||
ggml_tensor_extra_cl_q4_K * extra;
|
||||
if (temp_tensor_extras_q4_K.empty()) {
|
||||
@@ -4461,6 +4609,11 @@ struct ggml_backend_opencl_buffer_context {
|
||||
}
|
||||
temp_tensor_extras_q8_0_in_use.clear();
|
||||
|
||||
for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl_in_use) {
|
||||
temp_tensor_extras_iq4_nl.push_back(e);
|
||||
}
|
||||
temp_tensor_extras_iq4_nl_in_use.clear();
|
||||
|
||||
for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K_in_use) {
|
||||
temp_tensor_extras_q4_K.push_back(e);
|
||||
}
|
||||
@@ -4492,6 +4645,8 @@ struct ggml_backend_opencl_buffer_context {
|
||||
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0;
|
||||
std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_iq4_nl *> temp_tensor_extras_iq4_nl;
|
||||
std::vector<ggml_tensor_extra_cl_iq4_nl *> temp_tensor_extras_iq4_nl_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q4_K *> temp_tensor_extras_q4_K;
|
||||
std::vector<ggml_tensor_extra_cl_q4_K *> temp_tensor_extras_q4_K_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q5_K *> temp_tensor_extras_q5_K;
|
||||
@@ -5123,6 +5278,87 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_IQ4_NL) {
|
||||
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
|
||||
GGML_ASSERT(extra_orig && "Tensors in OpenCL backend should have been allocated and initialized");
|
||||
|
||||
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
||||
ggml_tensor_extra_cl_iq4_nl * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_iq4_nl();
|
||||
|
||||
size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
|
||||
size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)/2);
|
||||
GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
|
||||
|
||||
cl_int err;
|
||||
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
|
||||
ggml_nbytes(tensor), NULL, &err);
|
||||
CL_CHECK(err);
|
||||
CL_CHECK(clEnqueueWriteBuffer(
|
||||
queue, data_device, CL_TRUE, 0,
|
||||
ggml_nbytes(tensor), data, 0, NULL, NULL));
|
||||
|
||||
cl_buffer_region region;
|
||||
|
||||
// Create subbuffer for scales.
|
||||
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
|
||||
region.size = size_d;
|
||||
extra->d = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
|
||||
CL_CHECK(err);
|
||||
auto previous_origin = region.origin;
|
||||
|
||||
// Create subbuffer for quants.
|
||||
region.origin = align_to(previous_origin + size_d, backend_ctx->alignment);
|
||||
region.size = size_q;
|
||||
extra->q = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_iq4_nl;
|
||||
if (use_adreno_kernels(backend_ctx, tensor)) {
|
||||
kernel = backend_ctx->kernel_convert_block_iq4_nl_noshuffle;
|
||||
}
|
||||
#else
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_iq4_nl;
|
||||
#endif
|
||||
cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type);
|
||||
cl_uchar mask_0F = 0x0F;
|
||||
cl_uchar mask_F0 = 0xF0;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &n_blk));
|
||||
|
||||
size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64)*64, 1, 1};
|
||||
size_t local_work_size[] = {64, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
|
||||
tensor->extra = extra;
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
if (use_adreno_kernels(backend_ctx, tensor)) {
|
||||
int M = tensor->ne[1];
|
||||
int K = tensor->ne[0];
|
||||
GGML_ASSERT(K % 32 == 0);
|
||||
|
||||
// Transpose q as ushort
|
||||
transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M);
|
||||
// Transpose d as ushort
|
||||
transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M);
|
||||
}
|
||||
#endif
|
||||
return;
|
||||
}
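The IQ4_NL branch above splits the flat tensor bytes into a half-precision scale plane (size_d) and a packed 4-bit quant plane (size_q); each 32-element IQ4_NL block stores one fp16 scale plus 16 bytes of nibbles, so the two sub-buffers should add up exactly to ggml_nbytes(tensor), which is what the assert checks. A standalone recomputation of that arithmetic with an example shape (hypothetical sizes, same formulas):

```cpp
#include <cstdio>
#include <cstdint>

int main() {
    const int64_t ne[2]  = { 4096, 4096 };     // example IQ4_NL weight shape
    const int     blck   = 32;                 // elements per IQ4_NL block
    const int64_t n      = ne[0] * ne[1];
    const int64_t n_blk  = n / blck;
    const int64_t size_d = n_blk * 2;          // one fp16 scale per block
    const int64_t size_q = n_blk * (blck / 2); // 4 bits per weight
    const int64_t nbytes = n_blk * 18;         // 2-byte scale + 16 bytes of quants per block
    printf("n_blk=%lld size_d=%lld size_q=%lld sum=%lld nbytes=%lld\n",
           (long long) n_blk, (long long) size_d, (long long) size_q,
           (long long) (size_d + size_q), (long long) nbytes);
    return 0;
}
```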
|
||||
if (tensor->type == GGML_TYPE_Q4_K) {
|
||||
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
|
||||
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
|
||||
@@ -5775,6 +6011,78 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_IQ4_NL) {
|
||||
ggml_tensor_extra_cl_iq4_nl * extra = (ggml_tensor_extra_cl_iq4_nl *)tensor->extra;
|
||||
|
||||
cl_int err;
|
||||
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
|
||||
ggml_nbytes(tensor), NULL, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
if (use_adreno_kernels(backend_ctx, tensor)) {
|
||||
static ggml_cl_buffer buf_trans_q;
|
||||
static ggml_cl_buffer buf_trans_d;
|
||||
static ggml_cl_buffer buf_unpacked;
|
||||
|
||||
cl_int M = tensor->ne[1];
|
||||
cl_int K = tensor->ne[0];
|
||||
GGML_ASSERT(K % 32 == 0);
|
||||
|
||||
size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*(ggml_blck_size(tensor->type)/2);
|
||||
size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t);
|
||||
GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
|
||||
|
||||
buf_trans_q.allocate(backend_ctx->context, size_q);
|
||||
buf_trans_d.allocate(backend_ctx->context, size_d);
|
||||
buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor));
|
||||
|
||||
// transpose q, d back
|
||||
transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4);
|
||||
transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32);
|
||||
|
||||
cl_uchar mask_0F = 0x0F;
|
||||
cl_uchar mask_F0 = 0xF0;
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_iq4_nl_noshuffle;
|
||||
cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type);
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d.buffer));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_unpacked.buffer));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &n_blk));
|
||||
|
||||
size_t global_work_size[] = {(size_t)n_blk, 1, 1};
|
||||
size_t local_work_size[] = {1, 1, 1};
|
||||
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
|
||||
CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL));
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_iq4_nl;
|
||||
cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type);
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &n_blk));
|
||||
|
||||
size_t global_work_size[] = {(size_t)n_blk, 1, 1};
|
||||
size_t local_work_size[] = {1, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
|
||||
global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clEnqueueReadBuffer(
|
||||
queue, data_device, CL_TRUE, offset,
|
||||
size, data, 0, NULL, NULL));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q4_K) {
|
||||
ggml_tensor_extra_cl_q4_K * extra = (ggml_tensor_extra_cl_q4_K *)tensor->extra;
|
||||
|
||||
@@ -9840,6 +10148,178 @@ static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_t
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
GGML_ASSERT(src1);
|
||||
GGML_ASSERT(src1->extra);
|
||||
GGML_ASSERT(dst);
|
||||
GGML_ASSERT(dst->extra);
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra;
|
||||
|
||||
cl_ulong offset1 = extra1->offset + src1->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
const int ne00 = src0->ne[0];
|
||||
const int ne01 = src0->ne[1];
|
||||
|
||||
const int ne1 = dst->ne[1];
|
||||
|
||||
GGML_ASSERT(ne00 % 32 == 0);
|
||||
|
||||
cl_context context = backend_ctx->context;
|
||||
cl_kernel kernel;
|
||||
|
||||
cl_int err;
|
||||
cl_image_format img_fmt;
|
||||
cl_image_desc img_desc;
|
||||
cl_buffer_region region;
|
||||
|
||||
int M = ne01;
|
||||
int N = ne1;
|
||||
int K = ne00;
|
||||
|
||||
if (ne1 == 1) {
|
||||
cl_mem q_img = nullptr;
|
||||
cl_mem b_sub_buf = nullptr;
|
||||
cl_mem b_img = nullptr;
|
||||
|
||||
// image for q
|
||||
img_fmt = { CL_R, CL_UNSIGNED_INT32};
|
||||
memset(&img_desc, 0, sizeof(img_desc));
|
||||
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc.image_width = M * K / 2 / 4;
|
||||
img_desc.buffer = extra0_iq4_nl->q;
|
||||
CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
|
||||
|
||||
// subbuffer for activations
|
||||
region.origin = offset1;
|
||||
region.size = K * N * sizeof(float);
|
||||
CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
|
||||
|
||||
// image for activations
|
||||
img_fmt = {CL_RGBA, CL_FLOAT};
|
||||
memset(&img_desc, 0, sizeof(img_desc));
|
||||
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc.image_width = K * N / 4;
|
||||
img_desc.buffer = b_sub_buf;
|
||||
CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
|
||||
|
||||
kernel = backend_ctx->kernel_gemv_noshuffle_iq4_nl_f32;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01));
|
||||
|
||||
size_t local_work_size[3] = {64, 4, 1};
|
||||
size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
|
||||
CL_CHECK(clReleaseMemObject(q_img));
|
||||
CL_CHECK(clReleaseMemObject(b_sub_buf));
|
||||
CL_CHECK(clReleaseMemObject(b_img));
|
||||
} else {
|
||||
cl_mem b_sub_buf = nullptr;
|
||||
cl_mem b_sub_buf_trans = nullptr;
|
||||
cl_mem b_img = nullptr;
|
||||
cl_mem b_img_trans = nullptr;
|
||||
|
||||
// subbuffer for activations
|
||||
region.origin = offset1;
|
||||
region.size = K * N * sizeof(float);
|
||||
CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err));
|
||||
|
||||
// image for activations
|
||||
img_fmt = {CL_RGBA, CL_FLOAT};
|
||||
memset(&img_desc, 0, sizeof(img_desc));
|
||||
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc.image_width = K * N / 4;
|
||||
img_desc.buffer = b_sub_buf;
|
||||
CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
|
||||
|
||||
// pad N to multiple of 8
|
||||
int extra_elements = N % 8;
|
||||
int padding = 0;
|
||||
if (extra_elements > 0){
|
||||
padding = 8 - extra_elements;
|
||||
}
|
||||
|
||||
// subbuffer for transposed activations
|
||||
region.origin = 0;
|
||||
region.size = K * (N + padding) * sizeof(float)/2;
|
||||
backend_ctx->prealloc_act_trans.allocate(context, region.size);
|
||||
CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err));
|
||||
|
||||
// image for transposed activations
|
||||
img_fmt = {CL_RGBA, CL_HALF_FLOAT};
|
||||
memset(&img_desc, 0, sizeof(img_desc));
|
||||
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc.image_width = K * (N + padding) / 4;
|
||||
img_desc.buffer = b_sub_buf_trans;
|
||||
CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err));
|
||||
|
||||
// transpose activations
|
||||
int height_B = N/4;
|
||||
if (height_B == 0) {
|
||||
height_B = 1;
|
||||
}
|
||||
int width_B = K/4;
|
||||
int padded_height_B = (N + padding)/4;
|
||||
|
||||
kernel = backend_ctx->kernel_transpose_32_16;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B));
|
||||
|
||||
size_t local_work_size_t[2] = { 1, 16 };
|
||||
size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B };
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst);
|
||||
|
||||
// gemm
|
||||
kernel = backend_ctx->kernel_gemm_noshuffle_iq4_nl_f32;
|
||||
int padded_N = N + padding;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &padded_N));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne1));
|
||||
|
||||
size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1};
|
||||
size_t local_work_size[3] = {1, 128, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
|
||||
CL_CHECK(clReleaseMemObject(b_sub_buf));
|
||||
CL_CHECK(clReleaseMemObject(b_sub_buf_trans));
|
||||
CL_CHECK(clReleaseMemObject(b_img));
|
||||
CL_CHECK(clReleaseMemObject(b_img_trans));
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(backend);
|
||||
GGML_UNUSED(src0);
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
#endif
|
||||
}
|
||||
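// Added note on the launch geometry in ggml_cl_mul_mat_iq4_nl_f32_adreno above: the ne1 == 1
// path uses a 64x4 workgroup (four 64-wide subgroups that reduce in local memory) and each
// fiber of subgroup 0 writes two adjacent output rows, hence the CEIL_DIV(ne01/2, 64)*64
// global size; the batched path first transposes the activations to fp16 and then tiles the
// GEMM so that each work-item accumulates a 4-row by 8-column block of the output.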

static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
    GGML_ASSERT(src0);
@@ -10634,6 +11114,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
    ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra;
    ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
    ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
    ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra;
    ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra;
    ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra;
    ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra;
@@ -10738,6 +11219,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
        return;
    }

    // iq4_nl x fp32
    if (src0t == GGML_TYPE_IQ4_NL && src1t == GGML_TYPE_F32) {
        ggml_cl_mul_mat_iq4_nl_f32_adreno(backend, src0, src1, dst);
        return;
    }

    // q8_0 x fp32
    if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 &&
        enable_adreno_trans_weight(backend_ctx, src0)) {
@@ -11302,6 +11789,48 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
            backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
            return;
        }
        case GGML_TYPE_IQ4_NL: {
            if (ne11 < 32) {
                break;
            }
            if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
                break;
            }

            kernel = backend_ctx->kernel_mul_mm_iq4_nl_f32_l4_lm;
            nth0 = 128; // calculated as (BM*BN)/(TM*TN)

            int batch_stride_a = ne00*ne01;
            int batch_stride_b = ne10*ne11;
            int batch_stride_d = ne0*ne1;

            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q));
            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d));
            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
            CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
            CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
            CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));

            // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
            size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
            size_t local_work_size[] = {(size_t)nth0, 1, 1};

            backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
            return;
        }
        case GGML_TYPE_Q4_K: {
            if (ne11 < 32) {
                break;
@@ -11829,6 +12358,70 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1));
            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
#endif // GGML_OPENCL_SOA_Q
            break;
        }
        case GGML_TYPE_IQ4_NL: {
#ifdef GGML_OPENCL_SOA_Q
            kernel = backend_ctx->kernel_mul_mv_iq4_nl_f32_flat;

            if (backend_ctx->gpu_family == INTEL) {
                nth0 = 16;
                nth1 = 1;
                ndst = 8;
            } else if (backend_ctx->gpu_family == ADRENO) {
                nth0 = 64;
                nth1 = 1;
                ndst = 8;
            } else {
                GGML_ASSERT(false && "TODO: Unknown GPU");
            }

            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q));
            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d));
            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
            CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
            CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
            CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10));
            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0));
            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1));
            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2));
            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
#else
            kernel = backend_ctx->kernel_mul_mv_iq4_nl_f32;

            if (backend_ctx->gpu_family == INTEL) {
                nth0 = 16;
                nth1 = 1;
                ndst = 4;
            } else if (backend_ctx->gpu_family == ADRENO) {
                nth0 = 64;
                nth1 = 1;
                ndst = 4;
            } else {
                GGML_ASSERT(false && "TODO: Unknown GPU");
            }

            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
            CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
            CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
            CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10));
            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0));
            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1));
            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2));
            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
#endif // GGML_OPENCL_SOA_Q
            break;
        }
@@ -12131,6 +12724,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 ||
        src0t == GGML_TYPE_Q4_1 ||
        src0t == GGML_TYPE_Q8_0 ||
        src0t == GGML_TYPE_IQ4_NL ||
        src0t == GGML_TYPE_Q2_K) {
        // Each SIMD group produces N_DST values in the result. Assuming each
        // workgroup has N_SIMDGROUP SIMD groups, then each workgroup will
@@ -87,6 +87,17 @@ struct block_q6_K {
    half d; // super-block scale
};

//------------------------------------------------------------------------------
// block_iq4_nl
//------------------------------------------------------------------------------
#define QK4_NL 32

struct block_iq4_nl
{
    half d;
    uint8_t qs[QK4_NL / 2];
};
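// Added note: with QK4_NL == 32, each block_iq4_nl holds a 2-byte fp16 scale plus
// QK4_NL/2 == 16 bytes of packed nibbles, i.e. 18 bytes per 32 weights (4.5 bits/weight).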

//------------------------------------------------------------------------------
// kernel_convert_block_q4_0
// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
@@ -895,3 +906,99 @@ kernel void kernel_restore_block_q6_K_noshuffle(
        b->scales[i] = s[i];
    }
}

//------------------------------------------------------------------------------
// kernel_convert_block_iq4_nl
// Convert the block_iq4_nl format to 2 separate arrays (AOS -> SOA).
//------------------------------------------------------------------------------
kernel void kernel_convert_block_iq4_nl(
    global struct block_iq4_nl * src0,
    global uchar * dst_q,
    global half * dst_d,
    uchar mask_0F,
    uchar mask_F0,
    ulong n_blk
) {
    if (get_global_id(0) >= n_blk) {
        return;
    }
    global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0);
    global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0);
    global half * d = (global half *) dst_d + get_global_id(0);

    *d = b->d;

    for (int i = 0; i < QK4_NL/2; ++i) {
        q[i] = b->qs[i];
    }
}
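// Added note: for n_blk blocks the flattened SOA layout written above needs
// n_blk * QK4_NL/2 bytes in dst_q and n_blk halves in dst_d; kernel_restore_block_iq4_nl
// below is its exact inverse (used on the tensor read-back path).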

kernel void kernel_restore_block_iq4_nl(
    global uchar * src_q,
    global half * src_d,
    global struct block_iq4_nl * dst,
    ulong n_blk
) {
    if (get_global_id(0) >= n_blk) {
        return;
    }
    global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0);
    global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0);
    global half * d = (global half *) src_d + get_global_id(0);

    b->d = *d;

    for (int i = 0; i < QK4_NL/2; ++i) {
        b->qs[i] = q[i];
    }
}

kernel void kernel_convert_block_iq4_nl_noshuffle(
    global struct block_iq4_nl * src0,
    global uchar * dst_q,
    global half * dst_d,
    uchar mask_0F,
    uchar mask_F0,
    ulong n_blk
) {
    if (get_global_id(0) >= n_blk) {
        return;
    }
    global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0);
    global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0);
    global half * d = (global half *) dst_d + get_global_id(0);

    *d = b->d;
    for (int i = 0; i < QK4_NL/4; ++i) {
        uchar x0 = b->qs[2*i + 0];
        uchar x1 = b->qs[2*i + 1];

        q[i + 0 ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4);
        q[i + QK4_NL/4] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0);
    }
}

kernel void kernel_restore_block_iq4_nl_noshuffle(
    global uchar * src_q,
    global half * src_d,
    global struct block_iq4_nl * dst,
    uchar mask_0F,
    uchar mask_F0,
    ulong n_blk
) {
    if (get_global_id(0) >= n_blk) {
        return;
    }
    global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0);
    global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0);
    global half * d = (global half *) src_d + get_global_id(0);

    b->d = *d;
    for (int i = 0; i < QK4_NL/4; ++i) {
        uchar x0 = q[i + 0 ];
        uchar x1 = q[i + QK4_NL/4];

        b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4));
        b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0));
    }
}

150
ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl
Normal file
@@ -0,0 +1,150 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable

#ifdef cl_qcom_reqd_sub_group_size
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif

constant half kvalues_iq4nl[16] = {
    (half)-127.f, (half)-104.f, (half)-83.f, (half)-65.f,
    (half) -49.f, (half) -35.f, (half)-22.f, (half)-10.f,
    (half) 1.f, (half) 13.f, (half) 25.f, (half) 38.f,
    (half) 53.f, (half) 69.f, (half) 89.f, (half)113.f
};

// Packed LUT: 2 FP16 values per uint, 8 unique constant loads instead of 16
constant uint iq4nl_packed[8] = {
    0xD680D7F0u, // idx 0,1: -127, -104
    0xD410D530u, // idx 2,3: -83, -65
    0xD060D220u, // idx 4,5: -49, -35
    0xC900CD80u, // idx 6,7: -22, -10
    0x4A803C00u, // idx 8,9: 1, 13
    0x50C04E40u, // idx 10,11: 25, 38
    0x545052A0u, // idx 12,13: 53, 69
    0x57105590u  // idx 14,15: 89, 113
};

// Packed dequant: 1 uint constant load (8-way divergence) + shift + as_half
#define IQ4_NL_DEQUANT(nibble) as_half((ushort)(iq4nl_packed[(nibble) >> 1] >> (((nibble) & 1u) << 4)))
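// Added worked example for the macro above: nibble 9 selects iq4nl_packed[9 >> 1] = 0x4A803C00
// and shift (9 & 1) << 4 = 16, so the extracted ushort is 0x4A80, which as_half() decodes to
// 13.0h, matching kvalues_iq4nl[9]. Even nibbles take the low 16 bits, odd nibbles the high.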

#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif

kernel void kernel_gemm_noshuffle_iq4_nl_f32(
    global const ushort * src0_q,
    global const half * src0_d,
    read_only image1d_buffer_t src1,
    global float * dst,
    ulong offsetd,
    int m,
    int n,
    int k,
    int n_no_padding
) {
    dst = (global float *)((global char *)dst + offsetd);

    int m_4 = m >> 2;
    int n_4 = n >> 2;

    int gy = get_global_id(0);
    int gx = get_global_id(1);
    int gx_2 = gx << 2;

    half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
    half8 B;
    half4 dequantized_weights;

    global const ushort * weight_ptr = src0_q + gx_2;
    global const half * scale_ptr = src0_d + gx_2;

    for (int i = 0; i < k; i += 4) {
        B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4));
        B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1);

        ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m));

        half4 scale = vload4(0, scale_ptr + (i/32)*(m));

        // j=0
        dequantized_weights.s0 = IQ4_NL_DEQUANT(bits4.s0 & 0x000Fu) * scale.s0;
        dequantized_weights.s1 = IQ4_NL_DEQUANT(bits4.s1 & 0x000Fu) * scale.s1;
        dequantized_weights.s2 = IQ4_NL_DEQUANT(bits4.s2 & 0x000Fu) * scale.s2;
        dequantized_weights.s3 = IQ4_NL_DEQUANT(bits4.s3 & 0x000Fu) * scale.s3;
        c0 += B * dequantized_weights.s0;
        c1 += B * dequantized_weights.s1;
        c2 += B * dequantized_weights.s2;
        c3 += B * dequantized_weights.s3;

        // j=1
        B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4));
        B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1);
        dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 4) & 0x000Fu) * scale.s0;
        dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 4) & 0x000Fu) * scale.s1;
        dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 4) & 0x000Fu) * scale.s2;
        dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 4) & 0x000Fu) * scale.s3;
        c0 += B * dequantized_weights.s0;
        c1 += B * dequantized_weights.s1;
        c2 += B * dequantized_weights.s2;
        c3 += B * dequantized_weights.s3;

        // j=2
        B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4));
        B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1);
        dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 8) & 0x000Fu) * scale.s0;
        dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 8) & 0x000Fu) * scale.s1;
        dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 8) & 0x000Fu) * scale.s2;
        dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 8) & 0x000Fu) * scale.s3;
        c0 += B * dequantized_weights.s0;
        c1 += B * dequantized_weights.s1;
        c2 += B * dequantized_weights.s2;
        c3 += B * dequantized_weights.s3;

        // j=3
        B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4));
        B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1);
        dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 12) & 0x000Fu) * scale.s0;
        dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 12) & 0x000Fu) * scale.s1;
        dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 12) & 0x000Fu) * scale.s2;
        dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 12) & 0x000Fu) * scale.s3;
        c0 += B * dequantized_weights.s0;
        c1 += B * dequantized_weights.s1;
        c2 += B * dequantized_weights.s2;
        c3 += B * dequantized_weights.s3;
    }

    int idx = (gy<<3)*m + (gx<<2);

    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
        idx += m;
    }
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
        idx += m;
    }
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
        idx += m;
    }
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
        idx += m;
    }
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
        idx += m;
    }
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
        idx += m;
    }
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
        idx += m;
    }
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
    }
}
302
ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl
Normal file
@@ -0,0 +1,302 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
|
||||
#ifdef cl_qcom_reqd_sub_group_size
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#endif
|
||||
|
||||
#define QK4_NL 32
|
||||
#define NSUBGROUPS 4
|
||||
#define SUBGROUP_SIZE 64
|
||||
|
||||
constant half kvalues_iq4nl[16] = {
|
||||
(half)-127.f, (half)-104.f, (half)-83.f, (half)-65.f,
|
||||
(half) -49.f, (half) -35.f, (half)-22.f, (half)-10.f,
|
||||
(half) 1.f, (half) 13.f, (half) 25.f, (half) 38.f,
|
||||
(half) 53.f, (half) 69.f, (half) 89.f, (half)113.f
|
||||
};
|
||||
|
||||
// Packed LUT: 2 FP16 values per uint, 8 unique constant loads instead of 16
|
||||
constant uint iq4nl_packed[8] = {
|
||||
0xD680D7F0u, // idx 0,1: -127, -104
|
||||
0xD410D530u, // idx 2,3: -83, -65
|
||||
0xD060D220u, // idx 4,5: -49, -35
|
||||
0xC900CD80u, // idx 6,7: -22, -10
|
||||
0x4A803C00u, // idx 8,9: 1, 13
|
||||
0x50C04E40u, // idx 10,11: 25, 38
|
||||
0x545052A0u, // idx 12,13: 53, 69
|
||||
0x57105590u // idx 14,15: 89, 113
|
||||
};
|
||||
|
||||
// Packed dequant: 1 uint constant load (8-way divergence) + shift + as_half
|
||||
#define IQ4_NL_DEQUANT(nibble) as_half((ushort)(iq4nl_packed[(nibble) >> 1] >> (((nibble) & 1u) << 4)))
|
||||
|
||||
#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \
|
||||
float shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s0, 0); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s1, 0); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s2, 0); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s3, 0); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s4, 0); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s5, 0); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s6, 0); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s7, 0); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s0, 1); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s1, 1); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s2, 1); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s3, 1); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s4, 1); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s5, 1); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s6, 1); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s7, 1); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y; \
|
||||
|
||||
|
||||
#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \
|
||||
shared_y = sub_group_broadcast(y.s0, 2); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s1, 2); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s2, 2); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s3, 2); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s4, 2); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s5, 2); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s6, 2); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s7, 2); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s0, 3); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s1, 3); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s2, 3); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s3, 3); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s4, 3); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s5, 3); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s6, 3); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s7, 3); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y; \
|
||||
|
||||
|
||||
#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \
|
||||
float8 shared_y; \
|
||||
shared_y = sub_group_broadcast(y, 0); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y.s0; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y.s4; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y.s0; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y.s4; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \
|
||||
shared_y = sub_group_broadcast(y, 1); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y.s0; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y.s4; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y.s0; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y.s4; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \
|
||||
|
||||
|
||||
#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \
|
||||
shared_y = sub_group_broadcast(y, 2); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y.s0; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y.s4; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y.s0; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y.s4; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \
|
||||
shared_y = sub_group_broadcast(y, 3); \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y.s0; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y.s4; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \
|
||||
total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y.s0; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y.s4; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \
|
||||
total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \
|
||||
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_gemv_noshuffle_iq4_nl_f32(
|
||||
read_only image1d_buffer_t src0_q,
|
||||
global half2 * src0_d,
|
||||
read_only image1d_buffer_t src1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01)
|
||||
{
|
||||
uint groupId = get_local_id(1);
|
||||
uint gid = get_global_id(0);
|
||||
ushort slid = get_sub_group_local_id();
|
||||
|
||||
uint K = ne00;
|
||||
uint M = ne01;
|
||||
|
||||
uint LINE_STRIDE_A = M / 2;
|
||||
uint BLOCK_STRIDE_A = NSUBGROUPS * M;
|
||||
|
||||
private uint4 regA;
|
||||
private half2 regS;
|
||||
private float8 regB;
|
||||
|
||||
private float2 totalSum = (float2)(0.0f);
|
||||
|
||||
// loop along K in block granularity, skip 4 blocks every iter
|
||||
for (uint k = groupId; k < (K / QK4_NL); k += NSUBGROUPS) {
|
||||
regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows
|
||||
// first 4 fibers in each wave load 8 B values to its private scope
|
||||
if (slid < 4) {
|
||||
regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
|
||||
regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
|
||||
}
|
||||
|
||||
// load half weights for two blocks in consecutive rows
|
||||
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
|
||||
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
|
||||
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
|
||||
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
|
||||
#ifdef VECTOR_SUB_GROUP_BROADCAST
|
||||
dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);
|
||||
#else
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);
|
||||
#endif // VECTOR_SUB_GROUP_BROADCAST
|
||||
|
||||
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
|
||||
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
|
||||
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
|
||||
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
|
||||
#ifdef VECTOR_SUB_GROUP_BROADCAST
|
||||
dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);
|
||||
#else
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);
|
||||
#endif // VECTOR_SUB_GROUP_BROADCAST
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #wave=4
|
||||
local float2 reduceLM[SUBGROUP_SIZE * 3];
|
||||
if (groupId == 1) {
|
||||
reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum;
|
||||
}
|
||||
if (groupId == 2) {
|
||||
reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum;
|
||||
}
|
||||
if (groupId == 3) {
|
||||
reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if (groupId == 0) {
|
||||
totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid];
|
||||
}
|
||||
if (groupId == 0) {
|
||||
totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid];
|
||||
}
|
||||
if (groupId == 0) {
|
||||
totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid];
|
||||
}
|
||||
|
||||
// 2 outputs per fiber in wave 0
|
||||
if (groupId == 0) {
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
vstore2(totalSum, 0, &(dst[gid * 2]));
|
||||
}
|
||||
|
||||
}
|
||||
171
ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl
Normal file
@@ -0,0 +1,171 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define LOAD_VEC_A 8
|
||||
#define LOAD_VEC_B 4
|
||||
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 32
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
constant float kvalues_iq4nl[16] = {
|
||||
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f,
|
||||
1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
||||
};
|
||||
|
||||
kernel void kernel_mul_mm_iq4_nl_f32_l4_lm(
|
||||
global uchar4 * src0_q,
|
||||
global half * src0_d,
|
||||
global float4 * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne11,
|
||||
int ne12,
|
||||
|
||||
int stride_a,
|
||||
int stride_b,
|
||||
int stride_d,
|
||||
|
||||
int batch_stride_a,
|
||||
int batch_stride_b,
|
||||
int batch_stride_d,
|
||||
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float4*)((global char*)src1 + offset1);
|
||||
dst = (global float *)((global char*)dst + offsetd);
|
||||
|
||||
local float buf_a[BM * BK];
|
||||
local float buf_b[BN * BK];
|
||||
|
||||
const int batch_idx = get_global_id(2);
|
||||
|
||||
const int i13 = batch_idx / ne12;
|
||||
const int i12 = batch_idx % ne12;
|
||||
|
||||
const int i03 = i13 / r3;
|
||||
const int i02 = i12 / r2;
|
||||
|
||||
const int batch_idx_a = i03 * ne02 + i02;
|
||||
|
||||
const int ir = get_group_id(0);
|
||||
const int ic = get_group_id(1);
|
||||
|
||||
const int tid = get_local_id(0);
|
||||
const int th_r = tid % (BM / TM);
|
||||
const int th_c = tid / (BM / TM);
|
||||
|
||||
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
|
||||
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
|
||||
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
|
||||
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
|
||||
|
||||
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
|
||||
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
|
||||
|
||||
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
|
||||
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
|
||||
|
||||
float sums[TM * TN];
|
||||
float cache_a[TM];
|
||||
float cache_b[TN];
|
||||
|
||||
for (int i = 0; i < TM * TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (ir*BM + loadc_a + l < ne01) {
|
||||
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
int ib = idx / 4;
|
||||
int iqs = idx % 4;
|
||||
|
||||
float d = (float)src0_d[ib];
|
||||
global uchar4 * qs = src0_q + ib*4 + iqs;
|
||||
uchar4 q = *qs;
|
||||
// IQ4_NL: use lookup table instead of linear (nibble - 8)
|
||||
float4 v1 = (float4)(kvalues_iq4nl[(q.s0 )&0x0F], kvalues_iq4nl[(q.s1 )&0x0F],
|
||||
kvalues_iq4nl[(q.s2 )&0x0F], kvalues_iq4nl[(q.s3 )&0x0F])*d;
|
||||
float4 v2 = (float4)(kvalues_iq4nl[(q.s0>>4)&0x0F], kvalues_iq4nl[(q.s1>>4)&0x0F],
|
||||
kvalues_iq4nl[(q.s2>>4)&0x0F], kvalues_iq4nl[(q.s3>>4)&0x0F])*d;
|
||||
|
||||
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
|
||||
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
|
||||
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
|
||||
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
|
||||
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
|
||||
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
|
||||
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
|
||||
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
|
||||
} else {
|
||||
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (ic*BN + loadc_b + l < ne11) {
|
||||
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
|
||||
} else {
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
for (int i = 0; i < BK; i++) {
|
||||
for (int j = 0; j < TM; j++) {
|
||||
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
const int sums_idx = cc*TM + cr;
|
||||
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
const int dr = ir * BM + th_r * TM;
|
||||
const int dc = ic * BN + th_c * TN;
|
||||
|
||||
const int offsets = batch_idx * batch_stride_d;
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr + cr < ne01 && dc + cc < ne11) {
|
||||
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
164
ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl
Normal file
@@ -0,0 +1,164 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK4_NL 32
|
||||
|
||||
typedef char int8_t;
|
||||
typedef uchar uint8_t;
|
||||
typedef short int16_t;
|
||||
typedef ushort uint16_t;
|
||||
typedef int int32_t;
|
||||
typedef uint uint32_t;
|
||||
|
||||
constant float kvalues_iq4nl[16] = {
|
||||
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f,
|
||||
1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_iq4_nl
|
||||
//------------------------------------------------------------------------------
|
||||
struct block_iq4_nl
|
||||
{
|
||||
half d;
|
||||
uint8_t qs[QK4_NL / 2];
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// mul_vec_q_n_f32
|
||||
//------------------------------------------------------------------------------
|
||||
// Compute inner product between half a block of iq4_nl and 16 floats (yl).
|
||||
// il indicates where the quants begin (0 or 8).
|
||||
inline float block_iq4_nl_dot_y(
|
||||
global struct block_iq4_nl * qb_curr,
|
||||
private float * yl,
|
||||
int il
|
||||
) {
|
||||
float d = qb_curr->d;
|
||||
float acc = 0.f;
|
||||
global uchar * qs = qb_curr->qs + il;
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
acc += yl[i] * kvalues_iq4nl[qs[i] & 0x0F];
|
||||
acc += yl[i+8] * kvalues_iq4nl[qs[i] >> 4];
|
||||
}
|
||||
return d * acc;
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 4 // each subgroup group works on 4 rows
|
||||
#define N_SUBGROUP 1 // number of subgroups in a thread group
|
||||
#define N_SUBGROUP_SIZE 16 // assuming subgroup size is 16
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 4
|
||||
#define N_SUBGROUP 1
|
||||
#define N_SUBGROUP_SIZE 64
|
||||
#endif
|
||||
|
||||
inline void mul_vec_q_n_f32(
|
||||
global void * src0,
|
||||
global float * src1,
|
||||
global float * dst,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
|
||||
const ulong nb = ne00/QK4_NL;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
int first_row = (r0 * N_SUBGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
|
||||
global struct block_iq4_nl * x = (global struct block_iq4_nl *) src0 + offset0;
|
||||
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float yl[16]; // src1 vector cache
|
||||
float sumf[N_DST]={0.f};
|
||||
|
||||
int ix = get_sub_group_local_id()/2;
|
||||
int il = 8*(get_sub_group_local_id()%2);
|
||||
|
||||
global float * yb = y + ix * QK4_NL + il;
|
||||
|
||||
// each thread in a SIMD group deals with half a block.
|
||||
for (int ib = ix; ib < nb; ib += N_SUBGROUP_SIZE/2) {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
yl[i] = yb[i];
|
||||
yl[i+8] = yb[i+16];
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
sumf[row] += block_iq4_nl_dot_y(x+ib+row*nb, yl, il);
|
||||
}
|
||||
|
||||
yb += QK4_NL * (N_SUBGROUP_SIZE/2);
|
||||
}
|
||||
|
||||
float tot[N_DST] = {
|
||||
sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]),
|
||||
sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])};
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
if (get_sub_group_local_id() == 0 && first_row + row < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_iq4_nl_f32(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
|
||||
}
|
||||
202
ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl
Normal file
@@ -0,0 +1,202 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK4_NL 32
|
||||
|
||||
typedef char int8_t;
|
||||
typedef uchar uint8_t;
|
||||
typedef short int16_t;
|
||||
typedef ushort uint16_t;
|
||||
typedef int int32_t;
|
||||
typedef uint uint32_t;
|
||||
|
||||
constant float kvalues_iq4nl[16] = {
|
||||
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f,
|
||||
1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_iq4_nl
|
||||
//------------------------------------------------------------------------------
|
||||
struct block_iq4_nl
|
||||
{
|
||||
half d;
|
||||
uint8_t qs[QK4_NL / 2];
|
||||
};
|
||||
|
||||
// Compute dot product between half a block of iq4_nl quants and activations.
|
||||
// x points to the quant bytes, dh points to the scale.
|
||||
// yl has 16 activation values: [0..7] for low nibbles, [8..15] for high nibbles.
|
||||
// il indicates offset into the quant bytes (0 or 8).
|
||||
inline float block_iq4_nl_dot_y_flat(
|
||||
global uchar * x,
|
||||
global half * dh,
|
||||
private float * yl,
|
||||
int il
|
||||
) {
|
||||
float d = *dh;
|
||||
global uchar * qs = x + il;
|
||||
float acc = 0.f;
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
acc += yl[i] * kvalues_iq4nl[qs[i] & 0x0F];
|
||||
acc += yl[i+8] * kvalues_iq4nl[qs[i] >> 4];
|
||||
}
|
||||
return d * acc;
|
||||
}
|
||||
|
||||
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH

#ifdef INTEL_GPU
#define N_DST 8            // each subgroup works on 8 rows
#define N_SUBGROUP 1       // number of subgroups in a thread group
#define N_SUBGROUP_SIZE 16 // assuming subgroup size is 16
#elif defined (ADRENO_GPU)
#define N_DST 8
#define N_SUBGROUP 1
#define N_SUBGROUP_SIZE 64
#endif

inline void mul_vec_q_n_f32_8x_flat(
    global uchar * src0_q,
    global half  * src0_d,
    global float * src1,
    global float * dst,
    int ne00,
    int ne01,
    int ne02,
    int ne10,
    int ne12,
    int ne0,
    int ne1,
    int r2,
    int r3
) {
    const ulong nb = ne00/QK4_NL;

    int r0 = get_group_id(0);
    int r1 = get_group_id(1);
    int im = get_group_id(2);

    int first_row = (r0 * N_SUBGROUP + get_sub_group_id()) * N_DST;

    int i12 = im%ne12;
    int i13 = im/ne12;

    // The number of scales is the same as the number of blocks.
    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
    // Each block contains QK4_NL/2 uchars, hence offset for qs is as follows.
    ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_NL/2;

    global uchar * x = (global uchar *) src0_q + offset0_q;
    global half  * d = (global half  *) src0_d + offset0_d;
    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;

    float yl[16];
    float8 sumf = 0.f;

    int ix = get_sub_group_local_id()/2;
    int il = 8*(get_sub_group_local_id()%2);

    global float * yb = y + ix*QK4_NL + il;

    for (int ib = ix; ib < nb; ib += N_SUBGROUP_SIZE/2) {
        for (int i = 0; i < 8; ++i) {
            yl[i]   = yb[i];
            yl[i+8] = yb[i+16];
        }

        sumf.s0 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 0*nb*QK4_NL/2, d + ib + 0*nb, yl, il);
        sumf.s1 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 1*nb*QK4_NL/2, d + ib + 1*nb, yl, il);
        sumf.s2 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 2*nb*QK4_NL/2, d + ib + 2*nb, yl, il);
        sumf.s3 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 3*nb*QK4_NL/2, d + ib + 3*nb, yl, il);

        sumf.s4 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 4*nb*QK4_NL/2, d + ib + 4*nb, yl, il);
        sumf.s5 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 5*nb*QK4_NL/2, d + ib + 5*nb, yl, il);
        sumf.s6 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 6*nb*QK4_NL/2, d + ib + 6*nb, yl, il);
        sumf.s7 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 7*nb*QK4_NL/2, d + ib + 7*nb, yl, il);

        yb += QK4_NL * (N_SUBGROUP_SIZE/2);
    }

    float8 tot = (float8)(
        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
        sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
        sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
    );

    if (get_sub_group_local_id() == 0) {
        if (first_row + 0 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
        }
        if (first_row + 1 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
        }
        if (first_row + 2 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
        }
        if (first_row + 3 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
        }

        if (first_row + 4 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
        }
        if (first_row + 5 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
        }
        if (first_row + 6 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
        }
        if (first_row + 7 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
        }
    }
}

#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_iq4_nl_f32_flat(
    global uchar * src0_q,
    global half  * src0_d,
    global float * src1,
    ulong          offset1,
    global float * dst,
    ulong          offsetd,
    int ne00,
    int ne01,
    int ne02,
    int ne10,
    int ne12,
    int ne0,
    int ne1,
    int r2,
    int r3
) {
    src1 = (global float*)((global char*)src1 + offset1);
    dst  = (global float*)((global char*)dst  + offsetd);

    mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}

@@ -224,7 +224,7 @@ struct sycl_device_info {
    // cudaOccupancyMaxActiveBlocksPerMultiprocessor
    bool   vmm; // virtual memory support
    size_t total_vram;
    //sycl_hw_info hw_info; \\ device id and aarch, currently not used
    sycl_hw_info     hw_info;
    optimize_feature opt_feature;
};

@@ -104,6 +104,7 @@ static ggml_sycl_device_info ggml_sycl_init() {

        info.max_work_group_sizes[i] = prop.get_max_work_group_size();
        info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units();
        info.devices[i].hw_info = get_device_hw_info(&device);

    }

@@ -3703,9 +3704,16 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
    // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
    // is enabled takes precedence over DMMV, the current if-else implementation
    // requires disabling DMMV if both conditions are met

    if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) &&
                    ggml_sycl_supports_reorder_mmvq(src0->type)))) {
        use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
        // Arc770 get benefit with Q4_0 by skipping it.
        if (!(ggml_sycl_info().devices[ctx.device].hw_info.arch ==
                  gpu_arch::intel_gpu_acm_g10 &&
              src0->type == GGML_TYPE_Q4_0)) {
            use_dequantize_mul_mat_vec =
                use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
        }
    }

    if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {

@@ -1,15 +1,67 @@
#include "sycl_hw.hpp"

// TODO: currently not used
/*
sycl_hw_info get_device_hw_info(sycl::device *device_ptr) {
  sycl_hw_info res;
  int32_t id = device_ptr->get_info<sycl::ext::intel::info::device::device_id>();
  res.device_id = id;
using namespace std;

  syclex::architecture arch = device_ptr->get_info<syclex::info::device::architecture>();
  res.arch = arch;

  return res;
}
/*defined in
 * /opt/intel/oneapi/compiler/latest/include/sycl/ext/oneapi/experimental/device_architecture.def
 */
static map<gpu_arch, std::pair<const char*, sycl_intel_gpu_family>> arch2name = {
    {gpu_arch::intel_gpu_bdw,     {"intel_gpu_bdw",     GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_skl,     {"intel_gpu_skl",     GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_kbl,     {"intel_gpu_kbl",     GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_cfl,     {"intel_gpu_cfl",     GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_apl,     {"intel_gpu_apl",     GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_glk,     {"intel_gpu_glk",     GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_whl,     {"intel_gpu_whl",     GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_aml,     {"intel_gpu_aml",     GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_cml,     {"intel_gpu_cml",     GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_icllp,   {"intel_gpu_icllp",   GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_ehl,     {"intel_gpu_ehl",     GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_tgllp,   {"intel_gpu_tgllp",   GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_rkl,     {"intel_gpu_rkl",     GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_adl_s,   {"intel_gpu_adl_s",   GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_adl_p,   {"intel_gpu_adl_p",   GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_adl_n,   {"intel_gpu_adl_n",   GPU_FAMILY_IGPU_NON_XE}},
    {gpu_arch::intel_gpu_dg1,     {"intel_gpu_dg1",     GPU_FAMILY_DGPU_CLIENT_GAME}},
    {gpu_arch::intel_gpu_acm_g10, {"intel_gpu_acm_g10", GPU_FAMILY_DGPU_CLIENT_GAME}},
    {gpu_arch::intel_gpu_acm_g11, {"intel_gpu_acm_g11", GPU_FAMILY_DGPU_CLIENT_GAME}},
    {gpu_arch::intel_gpu_acm_g12, {"intel_gpu_acm_g12", GPU_FAMILY_DGPU_CLIENT_GAME}},
    {gpu_arch::intel_gpu_pvc,     {"intel_gpu_pvc",     GPU_FAMILY_DGPU_CLOUD}},
    {gpu_arch::intel_gpu_pvc_vg,  {"intel_gpu_pvc_vg",  GPU_FAMILY_DGPU_CLOUD}},
    {gpu_arch::intel_gpu_mtl_u,   {"intel_gpu_mtl_u",   GPU_FAMILY_IGPU_XE}},
    {gpu_arch::intel_gpu_mtl_h,   {"intel_gpu_mtl_h",   GPU_FAMILY_IGPU_XE}},
    {gpu_arch::intel_gpu_arl_h,   {"intel_gpu_arl_h",   GPU_FAMILY_IGPU_XE}},
    {gpu_arch::intel_gpu_bmg_g21, {"intel_gpu_bmg_g21", GPU_FAMILY_DGPU_CLIENT_GAME}},
    {gpu_arch::intel_gpu_bmg_g31, {"intel_gpu_bmg_g31", GPU_FAMILY_DGPU_CLIENT_GAME}},
    {gpu_arch::intel_gpu_lnl_m,   {"intel_gpu_lnl_m",   GPU_FAMILY_IGPU_XE}},
    {gpu_arch::intel_gpu_ptl_h,   {"intel_gpu_ptl_h",   GPU_FAMILY_IGPU_XE}},
    {gpu_arch::intel_gpu_ptl_u,   {"intel_gpu_ptl_u",   GPU_FAMILY_IGPU_XE}},
    {gpu_arch::intel_gpu_wcl,     {"intel_gpu_wcl",     GPU_FAMILY_IGPU_XE}}
};


sycl_hw_info get_device_hw_info(sycl::device* device_ptr) {
  sycl_hw_info res;
  int32_t id =
      device_ptr->get_info<sycl::ext::intel::info::device::device_id>();
  res.device_id = id;

  res.name = device_ptr->get_info<sycl::info::device::name>();

  syclex::architecture arch =
      device_ptr->get_info<syclex::info::device::architecture>();
  res.arch = arch;

  map<syclex::architecture,
      std::pair<const char*, sycl_intel_gpu_family>>::iterator it =
      arch2name.find(res.arch);
  if (it != arch2name.end()) {
    res.arch_name = it->second.first;
    res.gpu_family = it->second.second;
  } else {
    res.arch_name = "unknown";
    res.gpu_family = GPU_FAMILY_UKNOWN;
  }

  return res;
}

@@ -9,18 +9,30 @@
#include <sycl/sycl.hpp>

namespace syclex = sycl::ext::oneapi::experimental;
using gpu_arch = sycl::ext::oneapi::experimental::architecture;

// TODO: currently not used
/*
struct sycl_hw_info {
  syclex::architecture arch;
  int32_t device_id;
// It's used to mark the GPU computing capacity.
// The values must follow the order of performance.
enum sycl_intel_gpu_family {
    GPU_FAMILY_UKNOWN = -1,
    // iGPU without Xe core, before Meteor Lake iGPU(Xe)
    GPU_FAMILY_IGPU_NON_XE = 0,
    // iGPU with Xe core, Meteor Lake iGPU or newer.
    GPU_FAMILY_IGPU_XE = 1,
    // dGPU for gaming in client/data center (DG1/Flex 140 or newer).
    GPU_FAMILY_DGPU_CLIENT_GAME = 2,
    // dGPU for AI in cloud, PVC or newer.
    GPU_FAMILY_DGPU_CLOUD = 3
};

bool is_in_vector(std::vector<int> &vec, int item);
struct sycl_hw_info {
    syclex::architecture arch;
    const char* arch_name;
    int32_t device_id;
    std::string name;
    sycl_intel_gpu_family gpu_family;
};

sycl_hw_info get_device_hw_info(sycl::device *device_ptr);
*/


#endif // SYCL_HW_HPP

@@ -98,6 +98,29 @@ struct ggml_webgpu_ssm_conv_shader_decisions {
    uint32_t tokens_per_wg;
};

struct ggml_webgpu_ssm_scan_pipeline_key {
    int type;
    int d_state;

    bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const {
        return type == other.type && d_state == other.d_state;
    }
};

struct ggml_webgpu_ssm_scan_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_ssm_scan_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.d_state);
        return seed;
    }
};

struct ggml_webgpu_ssm_scan_shader_decisions {
    uint32_t wg_size;
    uint32_t tokens_per_tile;
};

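The key/hash pair above is what lets the shader library memoize compiled pipelines per shape/type in an std::unordered_map. A standalone sketch of the same find-or-create pattern (the hash_combine body here is an assumed boost-style combiner for illustration, not copied from ggml_webgpu_hash_combine):

#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

struct pipeline_key {
    int type;
    int d_state;
    bool operator==(const pipeline_key & o) const { return type == o.type && d_state == o.d_state; }
};

// Assumed boost-style combiner; the real helper in the backend may differ.
static void hash_combine(size_t & seed, size_t v) {
    seed ^= std::hash<size_t>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

struct pipeline_key_hash {
    size_t operator()(const pipeline_key & k) const {
        size_t seed = 0;
        hash_combine(seed, (size_t) k.type);
        hash_combine(seed, (size_t) k.d_state);
        return seed;
    }
};

std::unordered_map<pipeline_key, std::string, pipeline_key_hash> cache;

// Compile a variant on first use, then reuse it for every later node with the same key.
const std::string & get_or_build(const pipeline_key & key) {
    auto it = cache.find(key);
    if (it == cache.end()) {
        it = cache.emplace(key, "compiled_variant_d" + std::to_string(key.d_state)).first;
    }
    return it->second;
}
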
/** Argsort **/

struct ggml_webgpu_argsort_shader_lib_context {

@@ -436,19 +459,27 @@ struct ggml_webgpu_unary_pipeline_key_hash {

/** FlashAttention */

enum ggml_webgpu_flash_attn_path : uint32_t {
    GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 0u,
    GGML_WEBGPU_FLASH_ATTN_PATH_TILE            = 1u,
    GGML_WEBGPU_FLASH_ATTN_PATH_VEC             = 2u,
};

struct ggml_webgpu_flash_attn_pipeline_key {
    ggml_type kv_type;
    uint32_t  head_dim_qk;
    uint32_t  head_dim_v;
    bool      kv_direct;
    bool      kv_overlap;
    bool      has_mask;
    bool      has_sinks;
    bool      uses_logit_softcap;
    uint32_t  path;

    bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
        return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
               kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
               uses_logit_softcap == other.uses_logit_softcap;
               kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask &&
               has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap && path == other.path;
    }
};

@@ -459,39 +490,70 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
        ggml_webgpu_hash_combine(seed, key.head_dim_qk);
        ggml_webgpu_hash_combine(seed, key.head_dim_v);
        ggml_webgpu_hash_combine(seed, key.kv_direct);
        ggml_webgpu_hash_combine(seed, key.kv_overlap);
        ggml_webgpu_hash_combine(seed, key.has_mask);
        ggml_webgpu_hash_combine(seed, key.has_sinks);
        ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
        ggml_webgpu_hash_combine(seed, key.path);
        return seed;
    }
};

struct ggml_webgpu_flash_attn_decisions {
    uint32_t q_tile  = 0;
    uint32_t kv_tile = 0;
    uint32_t wg_size = 0;
    uint32_t path      = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
    uint32_t q_tile    = 0;
    uint32_t kv_tile   = 0;
    uint32_t wg_size   = 0;
    bool     kv_direct = false;
};

struct ggml_webgpu_flash_attn_vec_decisions {
    uint32_t kv_tile = 0;
    uint32_t wg_size = 0;
};
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u;
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE       = 4u;

inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
    if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 ||
        key.head_dim_qk != key.head_dim_v) {
        return 1u;
    }

    switch (key.head_dim_qk) {
        case 64:
        case 192:
        case 576:
            return 2u;
        case 96:
            return 4u;
        default:
            return 1u;
    }
}

inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key(
    const ggml_webgpu_shader_lib_context & context) {
    const ggml_webgpu_shader_lib_context & context,
    uint32_t                               path) {
    const bool has_mask  = context.src3 != nullptr;
    const bool has_sinks = context.src4 != nullptr;
    const bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
                           (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
    bool kv_direct = false;
    if (path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
        uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH;
        if (path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) {
            kv_direct_align = context.sg_mat_k;
        }
        kv_direct = (context.src1->type == GGML_TYPE_F16) &&
                    (context.src0->ne[0] % std::max(1u, kv_direct_align) == 0) &&
                    (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
    }

    ggml_webgpu_flash_attn_pipeline_key key = {};
    key.kv_type            = context.src1->type;
    key.head_dim_qk        = (uint32_t) context.src0->ne[0];
    key.head_dim_v         = (uint32_t) context.src2->ne[0];
    key.kv_direct          = kv_direct;
    key.kv_overlap         = context.src_overlap;
    key.has_mask           = has_mask;
    key.has_sinks          = has_sinks;
    key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
    key.path               = path;
    return key;
}

@@ -554,8 +616,16 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,

inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context &      context,
                                                   const ggml_webgpu_flash_attn_pipeline_key & key) {
    const size_t limit_bytes = context.wg_mem_limit_bytes;
    const size_t q_tile      = context.sg_mat_m;
    const size_t limit_bytes    = context.wg_mem_limit_bytes;
    uint32_t     q_tile         = context.sg_mat_m;
    uint32_t     kv_granularity = context.sg_mat_n;
    if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
        q_tile         = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
        kv_granularity = std::max(1u, context.max_subgroup_size);
    } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
        q_tile         = 1u;
        kv_granularity = 8u;
    }
    const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
                                2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
    size_t bytes_per_kv = 0;
@@ -568,23 +638,90 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
    bytes_per_kv += q_tile;
    bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
    const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
    return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
    return (max_kv_tile / kv_granularity) * kv_granularity;
}

inline uint32_t ggml_webgpu_flash_attn_vec_get_kv_tile(const ggml_webgpu_shader_lib_context & context) {
    const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
    const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
    uint32_t kv_tile = std::max(context.sg_mat_n, std::min(32u, min_kv_tile));
    kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
    const ggml_webgpu_shader_lib_context & context,
    size_t                                 storage_offset_alignment) {
    ggml_webgpu_flash_attn_decisions decisions = {};
    const size_t alignment = std::max<size_t>(1u, storage_offset_alignment);
    const auto * K = context.src1;
    const auto * V = context.src2;
    GGML_ASSERT(K != nullptr);
    GGML_ASSERT(V != nullptr);

    if (key.kv_direct) {
        kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
        while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
            kv_tile -= context.sg_mat_n;
    const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t {
        constexpr uintptr_t ptr_base_addr = 0x1000u;
        const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor;
        return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs;
    };

    const uint32_t k_offset_elems =
        (uint32_t) ((flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type));
    const uint32_t v_offset_elems =
        (uint32_t) ((flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type));
    const bool f16_vec4_aligned = (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) &&
                                  (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
    const bool kv_vec_type_supported =
        K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
    const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) &&
                         (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
                         kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
                         (context.src2->type == K->type);
    const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 &&
                          V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
                          (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
                          (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec;

    decisions.path = use_vec  ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
                     use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
                                GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;

    const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
    decisions.kv_direct = key.kv_direct;

    if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
        const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
        decisions.q_tile  = 1u;
        decisions.kv_tile = std::max(8u, std::min(32u, min_kv_tile));
        decisions.kv_tile = (decisions.kv_tile / 8u) * 8u;
        decisions.wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
        if (decisions.kv_direct) {
            decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
            while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
                decisions.kv_tile -= 8u;
            }
        }
        return decisions;
    }

    decisions.q_tile =
        decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m;
    decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
                            std::min(64u, ggml_webgpu_flash_attn_max_kv_tile(context, key)) :
                            std::min(ggml_webgpu_flash_attn_max_kv_tile(context, key),
                                     context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
    decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
                            GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE :
                            std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);

    if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
        const uint32_t tile_kv_granularity = std::max(1u, context.max_subgroup_size);
        decisions.kv_tile =
            std::max(tile_kv_granularity, (decisions.kv_tile / tile_kv_granularity) * tile_kv_granularity);
    }

    if (decisions.kv_direct) {
        GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
        while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
            decisions.kv_tile -= decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
                                     std::max(1u, context.max_subgroup_size) :
                                     context.sg_mat_n;
        }
    }

    return kv_tile;
    return decisions;
}

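The kv_tile chosen in the decisions above is always rounded down to a path-specific granularity and, when KV is read directly, shrunk until it evenly divides the padded KV length. A small self-contained illustration of that rounding (the pad and granularity values here are examples, not the backend's actual constants):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Example pad only; the real value is GGML_WEBGPU_KV_SEQ_PAD in the backend.
constexpr uint32_t KV_SEQ_PAD_EXAMPLE = 256u;

uint32_t round_kv_tile(uint32_t wanted, uint32_t granularity, bool kv_direct) {
    uint32_t kv_tile = (wanted / granularity) * granularity;  // round down to the granularity
    kv_tile = std::max(kv_tile, granularity);
    if (kv_direct) {
        kv_tile = std::min(kv_tile, KV_SEQ_PAD_EXAMPLE);
        while (KV_SEQ_PAD_EXAMPLE % kv_tile != 0) {           // must evenly divide the padded KV length
            kv_tile -= granularity;
        }
    }
    return kv_tile;
}

int main() {
    // e.g. subgroup-matrix path with sg_mat_n = 16: a requested 100 rounds to 96, then shrinks to 64 under kv_direct
    std::printf("%u\n", round_kv_tile(100u, 16u, true));
}
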
/** Matrix Multiplication **/

@@ -807,6 +944,8 @@ class ggml_webgpu_shader_lib {
        solve_tri_pipelines;  // type
    std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
        ssm_conv_pipelines;  // type/vectorized
    std::unordered_map<ggml_webgpu_ssm_scan_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_scan_pipeline_key_hash>
        ssm_scan_pipelines;  // type/d_state
    std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
                       webgpu_pipeline,
                       ggml_webgpu_gated_delta_net_pipeline_key_hash>
@@ -821,8 +960,6 @@ class ggml_webgpu_shader_lib {
        repeat_pipelines;  // type
    std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
        flash_attn_pipelines;
    std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
        flash_attn_vec_pipelines;
    std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
                       webgpu_pipeline,
                       ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
@@ -1321,6 +1458,53 @@ class ggml_webgpu_shader_lib {
        return ssm_conv_pipelines[key];
    }

    webgpu_pipeline get_ssm_scan_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_ssm_scan_pipeline_key key = {};
        key.type    = context.dst->type;
        key.d_state = (int) context.src0->ne[0];

        auto it = ssm_scan_pipelines.find(key);
        if (it != ssm_scan_pipelines.end()) {
            return it->second;
        }

        std::vector<std::string> defines;
        std::string variant = "ssm_scan";

        switch (key.type) {
            case GGML_TYPE_F32:
                variant += "_f32";
                break;
            default:
                GGML_ABORT("Unsupported type for ssm_scan shader");
        }

        const uint32_t wg_size = (uint32_t) key.d_state;

        constexpr uint32_t tokens_per_tile = 4u;

        defines.push_back("WG_SIZE=" + std::to_string(wg_size) + "u");
        defines.push_back("TOKENS_PER_TILE=" + std::to_string(tokens_per_tile) + "u");

        if (context.supports_subgroups) {
            defines.push_back("USE_SUBGROUP_REDUCTION");
            variant += "_sg_reduce";
        } else {
            variant += "_wg_reduce";
        }

        variant += "_d" + std::to_string(key.d_state);

        auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines);
        auto decisions = std::make_shared<ggml_webgpu_ssm_scan_shader_decisions>();
        decisions->wg_size         = wg_size;
        decisions->tokens_per_tile = tokens_per_tile;
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        pipeline.context = decisions;
        ssm_scan_pipelines[key] = pipeline;
        return ssm_scan_pipelines[key];
    }

    webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_gated_delta_net_pipeline_key key = {};
        key.type = context.dst->type;
@@ -2044,14 +2228,19 @@ class ggml_webgpu_shader_lib {
        return repeat_pipelines[key];
    }

    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
        const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
        auto it = flash_attn_pipelines.find(key);
    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context,
                                            size_t storage_offset_alignment) {
        const ggml_webgpu_flash_attn_decisions decisions =
            ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment);
        ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
        auto it = flash_attn_pipelines.find(key);
        if (it != flash_attn_pipelines.end()) {
            return it->second;
        }
        std::vector<std::string> defines;
        std::string variant = "flash_attn";
        std::string variant = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC  ? "flash_attn_vec" :
                              decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? "flash_attn_tile" :
                                                                                   "flash_attn";

        switch (key.kv_type) {
            case GGML_TYPE_F32:
@@ -2073,7 +2262,12 @@ class ggml_webgpu_shader_lib {

        if (key.has_mask) {
            defines.push_back("MASK");
            variant += "_mask";
            if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
                defines.push_back("BLK");
                variant += "_mask_blk";
            } else {
                variant += "_mask";
            }
        }
        if (key.has_sinks) {
            defines.push_back("SINKS");
@@ -2087,6 +2281,10 @@ class ggml_webgpu_shader_lib {
            defines.push_back("KV_DIRECT");
            variant += "_kvdirect";
        }
        if (key.kv_overlap) {
            defines.push_back("KV_OVERLAP");
            variant += "_kv_overlap";
        }

        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
        variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
@@ -2094,129 +2292,37 @@ class ggml_webgpu_shader_lib {
        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
        variant += std::string("_hsv") + std::to_string(key.head_dim_v);

        defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
        defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
        defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));

        auto decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>();
        decisions->q_tile = context.sg_mat_m;

        const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
        uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);

        if (key.kv_direct) {
            kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
            while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
                kv_tile -= context.sg_mat_n;
            }
        const char * shader_src = wgsl_flash_attn;
        if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
            defines.push_back("KV_GRANULARITY=8");
            defines.push_back(std::string("VEC_NE=") + std::to_string(ggml_webgpu_flash_attn_pick_vec_ne(key)) + "u");
            shader_src = wgsl_flash_attn_vec_split;
        } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
            shader_src = wgsl_flash_attn_tile;
            defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size));
            defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v)));
            variant += "_tile";
        } else {
            defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
            defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
            defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
        }

        decisions->kv_tile = kv_tile;
        decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);

        defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile));
        defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
        defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
        auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
        defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile));
        defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile));
        defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size));

        webgpu_pipeline pipeline =
            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant);
        pipeline.context = decisions;
            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
        pipeline.context = pipeline_decisions;
        flash_attn_pipelines[key] = pipeline;
        return flash_attn_pipelines[key];
    }

    webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
        const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
        auto it = flash_attn_vec_pipelines.find(key);
        if (it != flash_attn_vec_pipelines.end()) {
            return it->second;
        }

        std::vector<std::string> defines;
        std::string variant = "flash_attn_vec";

        switch (key.kv_type) {
            case GGML_TYPE_F32:
                defines.push_back("KV_F32");
                break;
            case GGML_TYPE_F16:
                defines.push_back("KV_F16");
                break;
            case GGML_TYPE_Q4_0:
                defines.push_back("KV_Q4_0");
                break;
            case GGML_TYPE_Q8_0:
                defines.push_back("KV_Q8_0");
                break;
            default:
                GGML_ABORT("Unsupported KV type for flash attention shader");
        }
        variant += std::string("_") + ggml_type_name(key.kv_type);

        if (key.has_mask) {
            defines.push_back("MASK");
            defines.push_back("BLK");
            variant += "_mask_blk";
        }
        if (key.has_sinks) {
            defines.push_back("SINKS");
            variant += "_sinks";
        }
        if (key.uses_logit_softcap) {
            defines.push_back("LOGIT_SOFTCAP");
            variant += "_lgsc";
        }
        if (key.kv_direct) {
            defines.push_back("KV_DIRECT");
            variant += "_kvdirect";
        }

        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
        variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);

        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
        variant += std::string("_hsv") + std::to_string(key.head_dim_v);

        defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
        defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
        defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
        defines.push_back("Q_TILE=1");

        auto decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>();
        decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
        decisions->wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
        uint32_t vec_ne = 1u;

        // Keep conservative defaults unless this is the f16 vec-split shape family.
        if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) {
            switch (key.head_dim_qk) {
                case 64:
                case 192:
                case 576:
                    vec_ne = 2u;
                    break;
                case 96:
                    vec_ne = 4u;
                    break;
                default:
                    break;
            }
        }

        defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
        defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
        defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");

        webgpu_pipeline pipeline =
            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant);
        pipeline.context = decisions;
        flash_attn_vec_pipelines[key] = pipeline;
        return flash_attn_vec_pipelines[key];
    }

    webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context) {
    webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) {
        ggml_webgpu_flash_attn_blk_pipeline_key key = {};
        key.kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
        key.kv_tile = kv_tile;
        auto it = flash_attn_blk_pipelines.find(key);
        if (it != flash_attn_blk_pipelines.end()) {
            return it->second;

@@ -389,23 +389,6 @@ static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_t
    return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
}

static bool ggml_webgpu_flash_attn_use_vec(webgpu_global_context & global_ctx,
                                           const ggml_tensor *     Q,
                                           const ggml_tensor *     K,
                                           const ggml_tensor *     V) {
    const size_t alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
    const uint32_t k_offset_elems =
        (uint32_t) ((ggml_webgpu_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type));
    const uint32_t v_offset_elems =
        (uint32_t) ((ggml_webgpu_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type));
    const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
    const bool kv_vec_type_supported =
        K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;

    return (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
           (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
}

static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
    size_t offset = ggml_webgpu_tensor_offset(t);
    return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
@@ -1132,6 +1115,80 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx,
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}

static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx,
                                              ggml_tensor *    src0,
                                              ggml_tensor *    src1,
                                              ggml_tensor *    src2,
                                              ggml_tensor *    src3,
                                              ggml_tensor *    src4,
                                              ggml_tensor *    src5,
                                              ggml_tensor *    src6,
                                              ggml_tensor *    dst) {
    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
    shader_lib_ctx.src0               = src0;
    shader_lib_ctx.dst                = dst;
    shader_lib_ctx.max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
    shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;

    webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx);

    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),

        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),

        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),

        (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
        (uint32_t) (src2->nb[2] / ggml_type_size(src2->type)),

        (uint32_t) src3->ne[0],
        (uint32_t) (src3->nb[1] / ggml_type_size(src3->type)),

        (uint32_t) (src4->nb[1] / ggml_type_size(src4->type)),
        (uint32_t) (src4->nb[2] / ggml_type_size(src4->type)),
        (uint32_t) (src4->nb[3] / ggml_type_size(src4->type)),

        (uint32_t) (src5->nb[1] / ggml_type_size(src5->type)),
        (uint32_t) (src5->nb[2] / ggml_type_size(src5->type)),
        (uint32_t) (src5->nb[3] / ggml_type_size(src5->type)),

        (uint32_t) src0->ne[0],
        (uint32_t) src0->ne[1],
        (uint32_t) src0->ne[2],
        (uint32_t) src4->ne[1],
        (uint32_t) src1->ne[2],
        (uint32_t) src1->ne[3],
        (uint32_t) ggml_nelements(src1),
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
        ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3),
        ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5),
        ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6), ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst),
    };

    const uint32_t total_wg       = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]);
    const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
    uint32_t wg_x;
    uint32_t wg_y;
    compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);

    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
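compute_2d_workgroups above folds a flat dispatch count into (wg_x, wg_y) when it would exceed the per-dimension workgroup limit. A minimal sketch of that kind of helper (an assumed reimplementation for illustration only, not the backend's definition; the shader is expected to bounds-check the flattened index):

#include <cstdint>

// Split `total` workgroups so that wg_x stays within max_per_dim and wg_x * wg_y >= total.
static void split_workgroups_2d(uint32_t total, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) {
    if (total <= max_per_dim) {
        wg_x = total;
        wg_y = 1;
        return;
    }
    wg_y = (total + max_per_dim - 1) / max_per_dim;  // ceil-divide into rows
    wg_x = (total + wg_y - 1) / wg_y;                // balance the columns; never exceeds max_per_dim
}
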
|
||||
static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
@@ -1567,7 +1624,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
|
||||
return ggml_backend_webgpu_build_multi(ctx, dispatches);
|
||||
}
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
ggml_tensor * Q,
|
||||
ggml_tensor * K,
|
||||
@@ -1585,13 +1641,29 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
|
||||
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const int has_mask = (mask != nullptr);
|
||||
const int has_sinks = (sinks != nullptr);
|
||||
const int has_mask = (mask != nullptr);
|
||||
const int has_sinks = (sinks != nullptr);
|
||||
const bool kv_overlap = ggml_webgpu_tensor_overlap(K, V) && K->type == V->type;
|
||||
|
||||
uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
|
||||
uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
|
||||
size_t kv_bind_offset = 0;
|
||||
size_t kv_bind_size = 0;
|
||||
if (kv_overlap) {
|
||||
const size_t k_bind_offset = ggml_webgpu_tensor_align_offset(ctx, K);
|
||||
const size_t v_bind_offset = ggml_webgpu_tensor_align_offset(ctx, V);
|
||||
const size_t k_bind_end = k_bind_offset + ggml_webgpu_tensor_binding_size(ctx, K);
|
||||
const size_t v_bind_end = v_bind_offset + ggml_webgpu_tensor_binding_size(ctx, V);
|
||||
kv_bind_offset = std::min(k_bind_offset, v_bind_offset);
|
||||
kv_bind_size = std::max(k_bind_end, v_bind_end) - kv_bind_offset;
|
||||
offset_k = (uint32_t) ((ggml_webgpu_tensor_offset(K) - kv_bind_offset) / ggml_type_size(K->type));
|
||||
offset_v = (uint32_t) ((ggml_webgpu_tensor_offset(V) - kv_bind_offset) / ggml_type_size(V->type));
|
||||
}
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
|
||||
offset_k,
|
||||
offset_v,
|
||||
has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
|
||||
has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
@@ -1619,10 +1691,15 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
};
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V),
|
||||
};
|
||||
uint32_t binding_index = 3;
|
||||
if (kv_overlap) {
|
||||
entries.push_back(
|
||||
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size));
|
||||
} else {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V));
|
||||
}
|
||||
uint32_t binding_index = kv_overlap ? 2u : 3u;
|
||||
if (has_mask) {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask));
|
||||
}
|
||||
@@ -1638,25 +1715,25 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
shader_lib_ctx.src3 = mask;
|
||||
shader_lib_ctx.src4 = sinks;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.src_overlap = kv_overlap;
|
||||
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
|
||||
shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
|
||||
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
|
||||
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
|
||||
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
|
||||
const bool use_vec = ggml_webgpu_flash_attn_use_vec(ctx->global_ctx, Q, K, V);
|
||||
webgpu_pipeline pipeline = use_vec ? ctx->shader_lib->get_flash_attn_vec_pipeline(shader_lib_ctx) :
|
||||
ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
|
||||
shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
|
||||
|
||||
if (!use_vec) {
|
||||
auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
|
||||
if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
|
||||
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_flash_attn_vec_decisions *>(pipeline.context.get());
|
||||
|
||||
wgpu::Buffer blk_buf = {};
|
||||
uint64_t blk_size_bytes = 0;
|
||||
uint32_t blk_nblk0 = 0;
|
||||
@@ -1695,10 +1772,12 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
tmp_bind_size = tmp_size_bytes;
|
||||
scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
|
||||
} else {
|
||||
// nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation.
|
||||
// nwg==1 writes final dst directly in vec-split; bind tmp to a tiny non-overlapping scratch region.
|
||||
tmp_size_bytes = WEBGPU_STORAGE_BUF_BINDING_MULT;
|
||||
tmp_buf = ggml_webgpu_tensor_buf(dst);
|
||||
tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst);
|
||||
tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst);
|
||||
tmp_bind_offset = scratch_offset;
|
||||
tmp_bind_size = tmp_size_bytes;
|
||||
scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
|
||||
}
|
||||
|
||||
webgpu_pipeline blk_pipeline;
|
||||
@@ -1713,7 +1792,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
|
||||
blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx;
|
||||
blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx);
|
||||
blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile);
|
||||
|
||||
blk_params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask
|
||||
@@ -1745,12 +1824,19 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
std::vector<wgpu::BindGroupEntry> split_entries = {
|
||||
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q),
|
||||
ggml_webgpu_tensor_binding_size(ctx, Q)),
|
||||
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K),
|
||||
ggml_webgpu_tensor_binding_size(ctx, K)),
|
||||
ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), ggml_webgpu_tensor_align_offset(ctx, V),
|
||||
ggml_webgpu_tensor_binding_size(ctx, V)),
|
||||
};
|
||||
uint32_t split_binding_index = 3;
|
||||
if (kv_overlap) {
|
||||
split_entries.push_back(
|
||||
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size));
|
||||
} else {
|
||||
split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K),
|
||||
ggml_webgpu_tensor_align_offset(ctx, K),
|
||||
ggml_webgpu_tensor_binding_size(ctx, K)));
|
||||
split_entries.push_back(ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V),
|
||||
ggml_webgpu_tensor_align_offset(ctx, V),
|
||||
ggml_webgpu_tensor_binding_size(ctx, V)));
|
||||
}
|
||||
uint32_t split_binding_index = kv_overlap ? 2u : 3u;
|
||||
if (has_mask) {
|
||||
split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask),
|
||||
ggml_webgpu_tensor_align_offset(ctx, mask),
|
||||
@@ -1820,7 +1906,6 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
|
||||
return ggml_backend_webgpu_build_multi(ctx, dispatches);
|
||||
}
|
||||
#endif // __EMSCRIPTEN__
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
bool is_unary = dst->op == GGML_OP_UNARY;
|
||||
@@ -2710,11 +2795,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return ggml_webgpu_mul_mat_id(ctx, src0, src1, src2, node);
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
#ifndef __EMSCRIPTEN__
|
||||
return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
|
||||
#else
|
||||
return std::nullopt;
|
||||
#endif
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
@@ -2757,6 +2838,9 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
|
||||
return ggml_webgpu_solve_tri(ctx, src0, src1, node);
|
||||
case GGML_OP_SSM_CONV:
|
||||
return ggml_webgpu_ssm_conv(ctx, src0, src1, node);
|
||||
case GGML_OP_SSM_SCAN:
|
||||
return ggml_webgpu_ssm_scan(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node->src[6],
|
||||
node);
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node);
|
||||
case GGML_OP_PAD:
|
||||
@@ -2815,7 +2899,10 @@ static void ggml_backend_webgpu_collect_profile_results(webgpu_context &
|
||||
}
|
||||
#endif
|
||||
|
||||
// Don't bother checking set_rows index overflow for now, since practically the WebGPU doesn't need to support
|
||||
// models that would require it right now.
|
||||
static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t & num_inflight_batches) {
|
||||
#ifdef GGML_WEBGPU_CHECK_SET_ROWS
|
||||
wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
|
||||
encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
|
||||
ctx->set_rows_host_error_buf.GetSize());
|
||||
@@ -2828,6 +2915,10 @@ static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t &
|
||||
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
|
||||
}
|
||||
ctx->set_rows_host_error_buf.Unmap();
|
||||
#else
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(num_inflight_batches);
|
||||
#endif
|
||||
}
|
||||
|
||||
static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
||||
@@ -2913,8 +3004,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
||||
ggml_backend_webgpu_check_set_rows(ctx, num_inflight_batches);
|
||||
}
|
||||
|
||||
ggml_backend_webgpu_wait_queue(ctx->global_ctx);
|
||||
|
||||
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
@@ -3257,13 +3346,19 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
||||
ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.wg_mem_limit_bytes =
|
||||
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
|
||||
shader_lib_ctx.supports_subgroup_matrix =
|
||||
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
|
||||
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
|
||||
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
|
||||
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
|
||||
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
|
||||
|
||||
if (ggml_webgpu_flash_attn_use_vec(ctx->webgpu_global_ctx, Q, K, V)) {
|
||||
const uint32_t kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(shader_lib_ctx);
|
||||
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
|
||||
shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
|
||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
const uint32_t kv_tile = decisions.kv_tile;
|
||||
|
||||
const uint32_t vec_nwg_cap = std::max(
|
||||
1u, std::min<uint32_t>(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
|
||||
@@ -3283,6 +3378,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
||||
const size_t tmp_size_bytes = ROUNDUP_POW2(
|
||||
(tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
res += tmp_size_bytes + align;
|
||||
} else {
|
||||
res += WEBGPU_STORAGE_BUF_BINDING_MULT + align;
|
||||
}
|
||||
if (mask != nullptr) {
|
||||
const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
|
||||
@@ -3431,12 +3528,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
|
||||
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
|
||||
|
||||
bool valid_subgroup_matrix_config = false;
|
||||
#ifndef __EMSCRIPTEN__
|
||||
// Accept f16 subgroup matrix configurations (square or non-square).
|
||||
// NVIDIA GPUs typically report square configs (e.g. 16x16x16),
|
||||
// while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
|
||||
// The shaders are already parameterized to handle any M/N/K dimensions.
|
||||
bool valid_subgroup_matrix_config = false;
|
||||
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
||||
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
|
||||
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
|
||||
@@ -3450,8 +3547,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
}
|
||||
}
|
||||
}
|
||||
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
|
||||
#endif
|
||||
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
|
||||
|
||||
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
|
||||
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
|
||||
@@ -3499,12 +3596,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
// Enable Dawn-specific toggles to increase native performance
|
||||
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
|
||||
// only for native performance?
|
||||
const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
|
||||
"disable_polyfills_on_integer_div_and_mod" };
|
||||
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
|
||||
const char * const deviceEnabledToggles[] = { "disable_robustness", "disable_workgroup_init",
|
||||
"disable_polyfills_on_integer_div_and_mod" };
|
||||
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
|
||||
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
|
||||
deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
|
||||
deviceTogglesDesc.enabledToggleCount = 4;
|
||||
deviceTogglesDesc.enabledToggleCount = 3;
|
||||
deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
|
||||
deviceTogglesDesc.disabledToggleCount = 1;
|
||||
|
||||
@@ -3782,33 +3879,63 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
case GGML_OP_FLASH_ATTN_EXT:
{
#ifndef __EMSCRIPTEN__
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
break;
}
// Head dimensions must be divisible by subgroup matrix dimensions
if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
break;
}
// Head dimensions must fit in workgroup memory with minimum tile sizes
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
if (min_bytes > limit_bytes) {
break;
}

supports_op = src0->type == GGML_TYPE_F32 &&
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
src2->type == src1->type && op->type == GGML_TYPE_F32;
#endif
if (!supports_op) {
break;
}
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.src2 = src2;
shader_lib_ctx.src3 = op->src[3];
shader_lib_ctx.src4 = op->src[4];
shader_lib_ctx.dst = const_cast<ggml_tensor *>(op);
shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
shader_lib_ctx.wg_mem_limit_bytes =
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;

const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
const size_t min_bytes =
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
if (min_bytes > limit_bytes) {
supports_op = false;
}
break;
}

if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
const size_t min_bytes =
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
if (min_bytes > limit_bytes) {
supports_op = false;
}
break;
}

if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
supports_op = false;
break;
}
const size_t min_bytes =
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
if (min_bytes > limit_bytes) {
supports_op = false;
}
break;
}
case GGML_OP_RMS_NORM:
@@ -3896,6 +4023,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_SSM_CONV:
supports_op = op->type == GGML_TYPE_F32;
break;
case GGML_OP_SSM_SCAN:
supports_op = op->type == GGML_TYPE_F32 &&
src0->ne[0] <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
break;
case GGML_OP_GATED_DELTA_NET:
{
const uint32_t s_v = (uint32_t) src2->ne[0];

@@ -138,25 +138,54 @@ struct Params {
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||
#define V K
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
|
||||
#endif
|
||||
|
||||
#if defined(MASK) && defined(SINKS)
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 5
|
||||
#define PARAMS_BINDING 6
|
||||
#elif defined(MASK)
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#elif defined(SINKS)
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
|
||||
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#else
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 5
|
||||
#define PARAMS_BINDING 6
|
||||
#endif
|
||||
#elif defined(MASK)
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
|
||||
#define DST_BINDING 3
|
||||
#define PARAMS_BINDING 4
|
||||
#else
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#endif
|
||||
#elif defined(SINKS)
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(2) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 3
|
||||
#define PARAMS_BINDING 4
|
||||
#else
|
||||
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#endif
|
||||
#else
|
||||
#ifdef KV_OVERLAP
|
||||
#define DST_BINDING 2
|
||||
#define PARAMS_BINDING 3
|
||||
#else
|
||||
#define DST_BINDING 3
|
||||
#define PARAMS_BINDING 4
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
|
||||
|
||||
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl (new file, 330 lines)
@@ -0,0 +1,330 @@
|
||||
enable f16;
|
||||
enable subgroups;
|
||||
|
||||
#define HEAD_DIM_QK 64
|
||||
#define HEAD_DIM_V 64
|
||||
#define KV_STAGE_STRIDE 64
|
||||
#define Q_TILE 4
|
||||
#define KV_TILE 64
|
||||
#define WG_SIZE 128
|
||||
|
||||
struct Params {
|
||||
offset_q: u32,
|
||||
offset_k: u32,
|
||||
offset_v: u32,
|
||||
offset_mask: u32,
|
||||
offset_sinks: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
n_heads: u32,
|
||||
seq_len_q: u32,
|
||||
seq_len_kv: u32,
|
||||
|
||||
stride_q1: u32,
|
||||
stride_q2: u32,
|
||||
stride_q3: u32,
|
||||
stride_k1: u32,
|
||||
stride_k2: u32,
|
||||
stride_k3: u32,
|
||||
stride_v1: u32,
|
||||
stride_v2: u32,
|
||||
stride_v3: u32,
|
||||
stride_mask3: u32,
|
||||
|
||||
q_per_kv: u32,
|
||||
|
||||
scale: f32,
|
||||
max_bias: f32,
|
||||
logit_softcap: f32,
|
||||
n_head_log2: f32,
|
||||
m0: f32,
|
||||
m1: f32,
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
|
||||
#define V K
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<vec4<f16>>;
|
||||
#endif
|
||||
|
||||
#if defined(MASK) && defined(SINKS)
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
|
||||
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#else
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 5
|
||||
#define PARAMS_BINDING 6
|
||||
#endif
|
||||
#elif defined(MASK)
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
|
||||
#define DST_BINDING 3
|
||||
#define PARAMS_BINDING 4
|
||||
#else
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#endif
|
||||
#elif defined(SINKS)
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(2) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 3
|
||||
#define PARAMS_BINDING 4
|
||||
#else
|
||||
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#endif
|
||||
#else
|
||||
#ifdef KV_OVERLAP
|
||||
#define DST_BINDING 2
|
||||
#define PARAMS_BINDING 3
|
||||
#else
|
||||
#define DST_BINDING 3
|
||||
#define PARAMS_BINDING 4
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
|
||||
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
|
||||
|
||||
const FLOAT_MIN: f32 = -1.0e9;
|
||||
const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u;
|
||||
const V_CHUNKS: u32 = HEAD_DIM_V / 4u;
|
||||
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
|
||||
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
|
||||
|
||||
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
|
||||
var<workgroup> kv_shmem: array<f16, KV_TILE * KV_STAGE_STRIDE>;
|
||||
var<workgroup> p_shmem: array<f32, Q_TILE * KV_TILE>;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(subgroup_id) subgroup_id: u32,
|
||||
@builtin(subgroup_size) subgroup_size: u32,
|
||||
@builtin(num_subgroups) num_subgroups: u32,
|
||||
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
|
||||
if (subgroup_size == 0u || num_subgroups < Q_TILE) {
|
||||
return;
|
||||
}
|
||||
|
||||
let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
|
||||
let wg_per_batch = wg_per_head * params.n_heads;
|
||||
|
||||
let dst2_stride = HEAD_DIM_V * params.n_heads;
|
||||
let dst3_stride = dst2_stride * params.seq_len_q;
|
||||
|
||||
let batch_idx = wg_id.x / wg_per_batch;
|
||||
let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
|
||||
let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
|
||||
let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
|
||||
let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride;
|
||||
let wg_in_batch = wg_id.x % wg_per_batch;
|
||||
|
||||
let head_idx = wg_in_batch / wg_per_head;
|
||||
let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
|
||||
let k_head_idx = head_idx / params.q_per_kv;
|
||||
let v_head_offset = v_batch_offset + k_head_idx * params.stride_v2;
|
||||
let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
|
||||
|
||||
let wg_in_head = wg_in_batch % wg_per_head;
|
||||
let q_row_start = wg_in_head * Q_TILE;
|
||||
let global_q_row = q_row_start + subgroup_id;
|
||||
let row_active = subgroup_id < Q_TILE && global_q_row < params.seq_len_q;
|
||||
|
||||
#ifdef MASK
|
||||
let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
|
||||
#endif
|
||||
|
||||
let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V;
|
||||
|
||||
let head = f32(head_idx);
|
||||
let slope = select(1.0,
|
||||
select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0),
|
||||
pow(params.m0, head + 1.0),
|
||||
head < params.n_head_log2),
|
||||
params.max_bias > 0.0);
|
||||
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
|
||||
let q_tile_row = elem_idx / HEAD_DIM_QK;
|
||||
let q_col = elem_idx % HEAD_DIM_QK;
|
||||
let head_q_row = q_row_start + q_tile_row;
|
||||
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
|
||||
q_shmem[elem_idx] = f16(select(
|
||||
0.0,
|
||||
Q[global_q_row_offset + q_col] * params.scale,
|
||||
head_q_row < params.seq_len_q));
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
var row_max = FLOAT_MIN;
|
||||
var exp_sum = 0.0;
|
||||
var out_regs: array<vec4<f32>, OUT_REGS_PER_LANE>;
|
||||
for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
|
||||
out_regs[reg_idx] = vec4<f32>(0.0);
|
||||
}
|
||||
|
||||
let q_base = subgroup_id * HEAD_DIM_QK;
|
||||
let subgroup_p_offset = subgroup_id * KV_TILE;
|
||||
|
||||
for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
|
||||
let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile);
|
||||
let score_slots = min(SCORE_REGS_PER_LANE, (kv_count + subgroup_size - 1u) / subgroup_size);
|
||||
let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size);
|
||||
var local_scores: array<f32, SCORE_REGS_PER_LANE>;
|
||||
for (var slot = 0u; slot < SCORE_REGS_PER_LANE; slot += 1u) {
|
||||
local_scores[slot] = FLOAT_MIN;
|
||||
}
|
||||
|
||||
for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) {
|
||||
let kv_local = vec_idx_local / Q_CHUNKS;
|
||||
let chunk = vec_idx_local % Q_CHUNKS;
|
||||
let global_k_row = kv_tile + kv_local;
|
||||
let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
|
||||
let k4 = K[k_vec_index];
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
kv_shmem[kv_off + 0u] = k4.x;
|
||||
kv_shmem[kv_off + 1u] = k4.y;
|
||||
kv_shmem[kv_off + 2u] = k4.z;
|
||||
kv_shmem[kv_off + 3u] = k4.w;
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
var local_max = FLOAT_MIN;
|
||||
if (row_active) {
|
||||
for (var slot = 0u; slot < score_slots; slot += 1u) {
|
||||
let kv_local = sg_inv_id + slot * subgroup_size;
|
||||
if (kv_local >= kv_count) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let global_k_row = kv_tile + kv_local;
|
||||
var dot_val = 0.0;
|
||||
for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) {
|
||||
let q_off = q_base + chunk * 4u;
|
||||
let qv = vec4<f32>(
|
||||
f32(q_shmem[q_off + 0u]),
|
||||
f32(q_shmem[q_off + 1u]),
|
||||
f32(q_shmem[q_off + 2u]),
|
||||
f32(q_shmem[q_off + 3u]));
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
let kv = vec4<f32>(
|
||||
f32(kv_shmem[kv_off + 0u]),
|
||||
f32(kv_shmem[kv_off + 1u]),
|
||||
f32(kv_shmem[kv_off + 2u]),
|
||||
f32(kv_shmem[kv_off + 3u]));
|
||||
dot_val += dot(qv, kv);
|
||||
}
|
||||
#ifdef LOGIT_SOFTCAP
|
||||
dot_val = params.logit_softcap * tanh(dot_val);
|
||||
#endif
|
||||
#ifdef MASK
|
||||
let mask_idx = mask_global_offset + subgroup_id * params.seq_len_kv + global_k_row;
|
||||
dot_val += slope * f32(mask[mask_idx]);
|
||||
#endif
|
||||
local_scores[slot] = dot_val;
|
||||
local_max = max(local_max, dot_val);
|
||||
}
|
||||
}
|
||||
|
||||
let tile_max = subgroupMax(local_max);
|
||||
let new_max = max(row_max, tile_max);
|
||||
let cur_exp = exp(row_max - new_max);
|
||||
exp_sum *= cur_exp;
|
||||
for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
|
||||
out_regs[reg_idx] *= cur_exp;
|
||||
}
|
||||
|
||||
var local_sum = 0.0;
|
||||
for (var slot = 0u; slot < score_slots; slot += 1u) {
|
||||
let kv_local = sg_inv_id + slot * subgroup_size;
|
||||
if (row_active && kv_local < kv_count) {
|
||||
let p = exp(local_scores[slot] - new_max);
|
||||
p_shmem[subgroup_p_offset + kv_local] = p;
|
||||
local_sum += p;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) {
|
||||
let kv_local = vec_idx_local / V_CHUNKS;
|
||||
let chunk = vec_idx_local % V_CHUNKS;
|
||||
let global_v_row = kv_tile + kv_local;
|
||||
let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
|
||||
let v4 = V[v_vec_index];
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
kv_shmem[kv_off + 0u] = v4.x;
|
||||
kv_shmem[kv_off + 1u] = v4.y;
|
||||
kv_shmem[kv_off + 2u] = v4.z;
|
||||
kv_shmem[kv_off + 3u] = v4.w;
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
let tile_sum = subgroupAdd(local_sum);
|
||||
exp_sum += tile_sum;
|
||||
row_max = new_max;
|
||||
|
||||
if (row_active) {
|
||||
for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) {
|
||||
let chunk = sg_inv_id + reg_idx * subgroup_size;
|
||||
if (chunk >= V_CHUNKS) {
|
||||
continue;
|
||||
}
|
||||
|
||||
var acc = out_regs[reg_idx];
|
||||
for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) {
|
||||
let p = p_shmem[subgroup_p_offset + kv_local];
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
let v4 = vec4<f32>(
|
||||
f32(kv_shmem[kv_off + 0u]),
|
||||
f32(kv_shmem[kv_off + 1u]),
|
||||
f32(kv_shmem[kv_off + 2u]),
|
||||
f32(kv_shmem[kv_off + 3u]));
|
||||
acc += p * v4;
|
||||
}
|
||||
out_regs[reg_idx] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
#ifdef SINKS
|
||||
if (row_active) {
|
||||
let sink_score = sinks[params.offset_sinks + head_idx];
|
||||
let sink_max = max(row_max, sink_score);
|
||||
let sink_scale = exp(row_max - sink_max);
|
||||
for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
|
||||
out_regs[reg_idx] *= sink_scale;
|
||||
}
|
||||
exp_sum = exp_sum * sink_scale + exp(sink_score - sink_max);
|
||||
row_max = sink_max;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (row_active) {
|
||||
let inv_exp_sum = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
|
||||
let row_base = dst_global_offset + subgroup_id * dst2_stride;
|
||||
let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size);
|
||||
for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) {
|
||||
let chunk = sg_inv_id + reg_idx * subgroup_size;
|
||||
if (chunk >= V_CHUNKS) {
|
||||
continue;
|
||||
}
|
||||
let dst_vec_index = (row_base + chunk * 4u) >> 2u;
|
||||
dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
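The per-head ALiBi slope in flash_attn_tile.wgsl above is computed with a nested select(); the following Python restatement is illustrative only, mirroring that WGSL expression rather than adding anything new, and may help when checking the shader against ggml's ALiBi convention.

```python
def alibi_slope(head: float, n_head_log2: float, m0: float, m1: float, max_bias: float) -> float:
    # Mirrors the nested select() in the shader: with no bias the slope is 1.0,
    # heads below n_head_log2 use m0^(head + 1), the rest use m1^(2*(head - n_head_log2) + 1).
    if max_bias <= 0.0:
        return 1.0
    if head < n_head_log2:
        return m0 ** (head + 1.0)
    return m1 ** (2.0 * (head - n_head_log2) + 1.0)
```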
@@ -15,7 +15,7 @@ struct Params {
nblk1: u32,
};

@group(0) @binding(0) var<storage, read> mask: array<f16>;
@group(0) @binding(0) var<storage, read_write> mask: array<f16>;
@group(0) @binding(1) var<storage, read_write> blk: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;


@@ -1,8 +1,6 @@
|
||||
diagnostic(off, chromium.subgroup_matrix_uniformity);
|
||||
diagnostic(off, subgroup_uniformity);
|
||||
enable f16;
|
||||
enable subgroups;
|
||||
enable chromium_experimental_subgroup_matrix;
|
||||
|
||||
#ifdef KV_F32
|
||||
#define KV_TYPE f32
|
||||
@@ -13,19 +11,14 @@ enable chromium_experimental_subgroup_matrix;
|
||||
#define HEAD_DIM_QK 64
|
||||
#define HEAD_DIM_V 64
|
||||
|
||||
|
||||
#define SG_MAT_M 8
|
||||
#define SG_MAT_N 8
|
||||
#define SG_MAT_K 8
|
||||
|
||||
#define Q_TILE SG_MAT_M
|
||||
#define KV_GRANULARITY 8
|
||||
#define KV_TILE 16
|
||||
#define WG_SIZE 64
|
||||
#ifndef VEC_NE
|
||||
#define VEC_NE 4u
|
||||
#endif
|
||||
|
||||
#define KV_BLOCKS (KV_TILE / SG_MAT_N)
|
||||
#define KV_BLOCKS (KV_TILE / KV_GRANULARITY)
|
||||
|
||||
#define BLOCK_SIZE 32
|
||||
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
||||
@@ -97,6 +90,14 @@ struct Params {
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
||||
#ifdef KV_OVERLAP
|
||||
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
|
||||
#endif
|
||||
#define V K
|
||||
#else
|
||||
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||
#else
|
||||
@@ -107,7 +108,22 @@ struct Params {
|
||||
#else
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
|
||||
#endif
|
||||
#endif
|
||||
#if defined(MASK) && defined(SINKS)
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
|
||||
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||
#ifdef BLK
|
||||
#define BLK_BINDING 4
|
||||
#define TMP_BINDING 5
|
||||
#define DST_BINDING 6
|
||||
#define PARAMS_BINDING 7
|
||||
#else
|
||||
#define TMP_BINDING 4
|
||||
#define DST_BINDING 5
|
||||
#define PARAMS_BINDING 6
|
||||
#endif
|
||||
#else
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
|
||||
#ifdef BLK
|
||||
@@ -120,7 +136,21 @@ struct Params {
|
||||
#define DST_BINDING 6
|
||||
#define PARAMS_BINDING 7
|
||||
#endif
|
||||
#endif
|
||||
#elif defined(MASK)
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
|
||||
#ifdef BLK
|
||||
#define BLK_BINDING 3
|
||||
#define TMP_BINDING 4
|
||||
#define DST_BINDING 5
|
||||
#define PARAMS_BINDING 6
|
||||
#else
|
||||
#define TMP_BINDING 3
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#endif
|
||||
#else
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
#ifdef BLK
|
||||
#define BLK_BINDING 4
|
||||
@@ -132,16 +162,30 @@ struct Params {
|
||||
#define DST_BINDING 5
|
||||
#define PARAMS_BINDING 6
|
||||
#endif
|
||||
#endif
|
||||
#elif defined(SINKS)
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(2) var<storage, read_write> sinks: array<f32>;
|
||||
#define TMP_BINDING 3
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#else
|
||||
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||
#define TMP_BINDING 4
|
||||
#define DST_BINDING 5
|
||||
#define PARAMS_BINDING 6
|
||||
#endif
|
||||
#else
|
||||
#ifdef KV_OVERLAP
|
||||
#define TMP_BINDING 2
|
||||
#define DST_BINDING 3
|
||||
#define PARAMS_BINDING 4
|
||||
#else
|
||||
#define TMP_BINDING 3
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef BLK
|
||||
@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
|
||||
@@ -153,7 +197,7 @@ struct Params {
|
||||
// Just a very small float value.
|
||||
const FLOAT_MIN: f32 = -1.0e9;
|
||||
|
||||
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
|
||||
var<workgroup> q_shmem: array<f16, HEAD_DIM_QK>;
|
||||
|
||||
#ifndef KV_DIRECT
|
||||
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
|
||||
@@ -161,31 +205,27 @@ const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
|
||||
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
|
||||
#endif
|
||||
|
||||
var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>;
|
||||
var<workgroup> o_shmem: array<f16, HEAD_DIM_V>;
|
||||
|
||||
#ifdef MASK
|
||||
// storage for mask values
|
||||
var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
|
||||
var<workgroup> mask_shmem: array<f16, KV_TILE>;
|
||||
#endif
|
||||
|
||||
// note that we reuse the same storage for both since we only need one at a time
|
||||
var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
|
||||
var<workgroup> inter_shmem: array<f16, KV_TILE>;
|
||||
|
||||
// Storage for row max and exp sum during online softmax
|
||||
var<workgroup> row_max_shmem: array<f32, Q_TILE>;
|
||||
var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
|
||||
var<workgroup> blk_state_wg: u32;
|
||||
|
||||
fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
|
||||
fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
|
||||
var v = select(FLOAT_MIN,
|
||||
f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
|
||||
f32(inter_shmem[kv_idx]) * params.scale,
|
||||
kv_idx < KV_TILE);
|
||||
#ifdef LOGIT_SOFTCAP
|
||||
v = params.logit_softcap * tanh(v);
|
||||
#endif
|
||||
#ifdef MASK
|
||||
if (apply_mask) {
|
||||
var mask_val = select(0.0,f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
|
||||
var mask_val = select(0.0, f32(mask_shmem[kv_idx]), kv_idx < KV_TILE);
|
||||
v += select(mask_val, slope * mask_val, has_bias);
|
||||
}
|
||||
#endif
|
||||
@@ -199,19 +239,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(subgroup_size) subgroup_size: u32,
|
||||
@builtin(num_subgroups) num_subgroups: u32,
|
||||
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
|
||||
// Vec path processes exactly one query row per workgroup, so subgroup 0 can
|
||||
// keep the running softmax state in private storage.
|
||||
var row_max = FLOAT_MIN;
|
||||
var exp_sum = 0.0;
|
||||
|
||||
// initialize row max for online softmax
|
||||
for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
|
||||
row_max_shmem[i] = FLOAT_MIN;
|
||||
exp_sum_shmem[i] = 0.0;
|
||||
}
|
||||
|
||||
for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
|
||||
for (var i = local_id.x; i < HEAD_DIM_V; i += WG_SIZE) {
|
||||
o_shmem[i] = 0.0;
|
||||
}
|
||||
|
||||
// workgroups per head/batch
|
||||
let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
|
||||
let wg_per_head = params.seq_len_q;
|
||||
let wg_per_batch = wg_per_head * params.n_heads;
|
||||
|
||||
let dst2_stride = HEAD_DIM_V * params.n_heads;
|
||||
@@ -235,9 +273,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
|
||||
let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
|
||||
|
||||
// starting Q row for this workgroup
|
||||
// Vec path handles one Q row per workgroup.
|
||||
let wg_in_head = wg_in_batch % wg_per_head;
|
||||
let q_row_start = wg_in_head * Q_TILE;
|
||||
let q_row_start = wg_in_head;
|
||||
|
||||
#ifdef MASK
|
||||
// mask offset
|
||||
@@ -248,21 +286,18 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
let has_bias = params.max_bias > 0.0;
|
||||
let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias);
|
||||
|
||||
// load q tile into shared memory
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
|
||||
let q_row = elem_idx / HEAD_DIM_QK;
|
||||
let q_col = elem_idx % HEAD_DIM_QK;
|
||||
let head_q_row = q_row_start + q_row;
|
||||
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
|
||||
// load the single Q row into shared memory
|
||||
for (var elem_idx = local_id.x; elem_idx < HEAD_DIM_QK; elem_idx += WG_SIZE) {
|
||||
let global_q_row_offset = q_head_offset + q_row_start * params.stride_q1;
|
||||
q_shmem[elem_idx] = f16(select(
|
||||
0.0,
|
||||
Q[global_q_row_offset + q_col],
|
||||
head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
|
||||
Q[global_q_row_offset + elem_idx],
|
||||
q_row_start < params.seq_len_q));
|
||||
}
|
||||
|
||||
for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
|
||||
#ifdef BLK
|
||||
let q_blk = q_row_start / Q_TILE;
|
||||
let q_blk = q_row_start;
|
||||
let kv_blk = kv_tile / KV_TILE;
|
||||
let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
|
||||
let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk;
|
||||
@@ -270,13 +305,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
#else
|
||||
let blk_state_local = 1u;
|
||||
#endif
|
||||
if (local_id.x == 0u) {
|
||||
blk_state_wg = blk_state_local;
|
||||
}
|
||||
workgroupBarrier();
|
||||
let blk_state = blk_state_wg;
|
||||
let blk_state = blk_state_local;
|
||||
let skip_tile = blk_state == 0u;
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
|
||||
for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) {
|
||||
inter_shmem[elem_idx] = f16(0.0);
|
||||
}
|
||||
|
||||
@@ -360,20 +391,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
let num_of_threads = subgroup_size / VEC_NE;
|
||||
let tx = sg_inv_id % num_of_threads;
|
||||
let ty = sg_inv_id / num_of_threads;
|
||||
for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) {
|
||||
continue;
|
||||
}
|
||||
let local_q_row_offset = q_tile_row * HEAD_DIM_QK;
|
||||
|
||||
if (subgroup_id == 0u && q_row_start < params.seq_len_q) {
|
||||
for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) {
|
||||
let kv_idx = kv_base + ty;
|
||||
var partial_sum: f32 = 0.0;
|
||||
let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv;
|
||||
if (kv_valid) {
|
||||
for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) {
|
||||
let q_off = local_q_row_offset + i * 4u;
|
||||
let q_off = i * 4u;
|
||||
|
||||
let qv = vec4<f32>(
|
||||
f32(q_shmem[q_off + 0u]),
|
||||
@@ -410,8 +435,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
|
||||
let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
|
||||
if (tx == 0u && kv_valid) {
|
||||
let dst_idx = q_tile_row * KV_TILE + kv_idx;
|
||||
inter_shmem[dst_idx] = f16(sum_bcast);
|
||||
inter_shmem[kv_idx] = f16(sum_bcast);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -422,13 +446,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
let apply_mask = !skip_tile && (blk_state != 2u);
|
||||
if (apply_mask) {
|
||||
// load mask tile into shared memory for this KV block
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
|
||||
let mask_row = elem_idx / KV_TILE;
|
||||
let mask_col = elem_idx % KV_TILE;
|
||||
let global_q_row = q_row_start + mask_row;
|
||||
let global_k_col = kv_tile + mask_col;
|
||||
let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
|
||||
let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
|
||||
for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) {
|
||||
let global_k_col = kv_tile + elem_idx;
|
||||
let mask_in_bounds = q_row_start < params.seq_len_q && global_k_col < params.seq_len_kv;
|
||||
let mask_idx = mask_global_offset + global_k_col;
|
||||
mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
|
||||
}
|
||||
}
|
||||
@@ -439,50 +460,40 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
workgroupBarrier();
|
||||
|
||||
// online softmax
|
||||
if (!skip_tile) {
|
||||
for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) {
|
||||
break;
|
||||
}
|
||||
if (!skip_tile && subgroup_id == 0u && q_row_start < params.seq_len_q) {
|
||||
var prev_max = row_max;
|
||||
var final_max = prev_max;
|
||||
// pass 1: compute final max across the full KV tile in chunks
|
||||
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
|
||||
let kv_idx = kv_offset + sg_inv_id;
|
||||
let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE;
|
||||
let softmax_term = select(FLOAT_MIN,
|
||||
calc_softmax_term(kv_idx, slope, has_bias, apply_mask),
|
||||
kv_valid);
|
||||
final_max = subgroupMax(max(final_max, softmax_term));
|
||||
}
|
||||
|
||||
var prev_max = row_max_shmem[q_tile_row];
|
||||
var final_max = prev_max;
|
||||
// pass 1: compute final max across the full KV tile in chunks
|
||||
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
|
||||
let kv_idx = kv_offset + sg_inv_id;
|
||||
let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE;
|
||||
let softmax_term = select(FLOAT_MIN,
|
||||
calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask),
|
||||
kv_valid);
|
||||
final_max = subgroupMax(max(final_max, softmax_term));
|
||||
var total_exp_term: f32 = 0.0;
|
||||
// pass 2: compute exp sum and write P using final_max
|
||||
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
|
||||
let kv_idx = kv_offset + sg_inv_id;
|
||||
let softmax_term = calc_softmax_term(kv_idx, slope, has_bias, apply_mask);
|
||||
let cur_p = select(0.0,
|
||||
exp(softmax_term - final_max),
|
||||
kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
|
||||
total_exp_term += subgroupAdd(cur_p);
|
||||
if (kv_idx < KV_TILE) {
|
||||
inter_shmem[kv_idx] = f16(cur_p);
|
||||
}
|
||||
}
|
||||
|
||||
var total_exp_term: f32 = 0.0;
|
||||
// pass 2: compute exp sum and write P using final_max
|
||||
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
|
||||
let kv_idx = kv_offset + sg_inv_id;
|
||||
let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask);
|
||||
let cur_p = select(0.0,
|
||||
exp(softmax_term - final_max),
|
||||
kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
|
||||
total_exp_term += subgroupAdd(cur_p);
|
||||
if (kv_idx < KV_TILE) {
|
||||
inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
|
||||
}
|
||||
}
|
||||
let cur_exp = exp(prev_max - final_max);
|
||||
|
||||
let cur_exp = exp(prev_max - final_max);
|
||||
row_max = final_max;
|
||||
exp_sum = exp_sum * cur_exp + total_exp_term;
|
||||
|
||||
if (sg_inv_id == 0) {
|
||||
row_max_shmem[q_tile_row] = final_max;
|
||||
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
|
||||
}
|
||||
|
||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
|
||||
o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
|
||||
}
|
||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||
o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * cur_exp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -562,15 +573,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
workgroupBarrier();
|
||||
|
||||
if (!skip_tile) {
|
||||
// we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
|
||||
// we have P (KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
|
||||
// we want to compute O += P * V across the full KV tile
|
||||
let ne_threads : u32 = VEC_NE;
|
||||
let nl_threads = max(1u, subgroup_size / ne_threads);
|
||||
let tx_pv = sg_inv_id % nl_threads;
|
||||
let ty_pv = sg_inv_id / nl_threads;
|
||||
for (var q_tile_row = subgroup_id;
|
||||
q_tile_row < Q_TILE;
|
||||
q_tile_row += num_subgroups) {
|
||||
if (subgroup_id == 0u && q_row_start < params.seq_len_q) {
|
||||
for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) {
|
||||
var lo = vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||
for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) {
|
||||
@@ -580,7 +589,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
continue;
|
||||
}
|
||||
|
||||
let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]);
|
||||
let p = f32(inter_shmem[kv_idx]);
|
||||
#ifdef KV_DIRECT
|
||||
let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u;
|
||||
let v4 = vec4<f32>(V[v_idx >> 2u]);
|
||||
@@ -621,11 +630,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
|
||||
if (ty_pv == 0u) {
|
||||
let elem_base = vec_col * 4u;
|
||||
let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base;
|
||||
o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x);
|
||||
o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y);
|
||||
o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z);
|
||||
o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w);
|
||||
o_shmem[elem_base + 0u] = f16(f32(o_shmem[elem_base + 0u]) + lo_x);
|
||||
o_shmem[elem_base + 1u] = f16(f32(o_shmem[elem_base + 1u]) + lo_y);
|
||||
o_shmem[elem_base + 2u] = f16(f32(o_shmem[elem_base + 2u]) + lo_z);
|
||||
o_shmem[elem_base + 3u] = f16(f32(o_shmem[elem_base + 3u]) + lo_w);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -637,70 +645,46 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
|
||||
#ifdef SINKS
|
||||
// Sinks are global terms and must be applied exactly once across split workgroups.
|
||||
if (iwg == 0u) {
|
||||
for (var q_tile_row = subgroup_id;
|
||||
q_tile_row < Q_TILE;
|
||||
q_tile_row += num_subgroups) {
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) {
|
||||
break;
|
||||
}
|
||||
if (iwg == 0u && subgroup_id == 0u && q_row_start < params.seq_len_q) {
|
||||
var prev_max = row_max;
|
||||
|
||||
var prev_max = row_max_shmem[q_tile_row];
|
||||
// for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
|
||||
let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0u);
|
||||
let new_max = subgroupMax(max(prev_max, sink_val));
|
||||
let max_exp = exp(prev_max - new_max);
|
||||
let sink_exp = exp(sink_val - new_max);
|
||||
|
||||
// for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
|
||||
let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
|
||||
let new_max = subgroupMax(max(prev_max, sink_val));
|
||||
let max_exp = exp(prev_max - new_max);
|
||||
let sink_exp = exp(sink_val - new_max);
|
||||
let sink_exp_sum = subgroupAdd(sink_exp);
|
||||
|
||||
let sink_exp_sum = subgroupAdd(sink_exp);
|
||||
row_max = new_max;
|
||||
exp_sum = exp_sum * max_exp + sink_exp_sum;
|
||||
|
||||
if (sg_inv_id == 0) {
|
||||
row_max_shmem[q_tile_row] = new_max;
|
||||
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
|
||||
}
|
||||
|
||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
|
||||
o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp);
|
||||
}
|
||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||
o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * max_exp);
|
||||
}
|
||||
workgroupBarrier();
|
||||
}
|
||||
workgroupBarrier();
|
||||
#endif
|
||||
let rows_per_batch = params.n_heads * params.seq_len_q;
|
||||
for (var q_tile_row = subgroup_id;
|
||||
q_tile_row < Q_TILE;
|
||||
q_tile_row += num_subgroups) {
|
||||
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) { break; }
|
||||
|
||||
if (subgroup_id == 0u && q_row_start < params.seq_len_q) {
|
||||
if (params.nwg == 1u) {
|
||||
let exp_sum = exp_sum_shmem[q_tile_row];
|
||||
let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
|
||||
let row_base: u32 =
|
||||
params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V;
|
||||
let row_base: u32 = params.offset_dst + batch_idx * dst3_stride + q_row_start * dst2_stride +
|
||||
head_idx * HEAD_DIM_V;
|
||||
|
||||
for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) {
|
||||
let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
|
||||
let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
|
||||
let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
|
||||
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
|
||||
|
||||
let v = vec4<f32>(
|
||||
f32(o_shmem[i0]) * scale,
|
||||
f32(o_shmem[i1]) * scale,
|
||||
f32(o_shmem[i2]) * scale,
|
||||
f32(o_shmem[i3]) * scale
|
||||
f32(o_shmem[elem_base + 0u]) * scale,
|
||||
f32(o_shmem[elem_base + 1u]) * scale,
|
||||
f32(o_shmem[elem_base + 2u]) * scale,
|
||||
f32(o_shmem[elem_base + 3u]) * scale
|
||||
);
|
||||
|
||||
let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
|
||||
dst[dst_vec_index] = v;
|
||||
}
|
||||
} else {
|
||||
let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row;
|
||||
let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + q_row_start;
|
||||
let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V;
|
||||
let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg;
|
||||
|
||||
@@ -708,21 +692,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
elem_base < HEAD_DIM_V;
|
||||
elem_base += subgroup_size * 4u) {
|
||||
|
||||
let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
|
||||
let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
|
||||
let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
|
||||
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
|
||||
|
||||
let tbase = tmp_row_data_base + elem_base;
|
||||
tmp[tbase + 0u] = f32(o_shmem[i0]);
|
||||
tmp[tbase + 1u] = f32(o_shmem[i1]);
|
||||
tmp[tbase + 2u] = f32(o_shmem[i2]);
|
||||
tmp[tbase + 3u] = f32(o_shmem[i3]);
|
||||
tmp[tbase + 0u] = f32(o_shmem[elem_base + 0u]);
|
||||
tmp[tbase + 1u] = f32(o_shmem[elem_base + 1u]);
|
||||
tmp[tbase + 2u] = f32(o_shmem[elem_base + 2u]);
|
||||
tmp[tbase + 3u] = f32(o_shmem[elem_base + 3u]);
|
||||
}
|
||||
|
||||
if (sg_inv_id == 0u) {
|
||||
tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row];
|
||||
tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row];
|
||||
tmp[tmp_row_stats_base + 0u] = exp_sum;
|
||||
tmp[tmp_row_stats_base + 1u] = row_max;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl (new file, 168 lines)
@@ -0,0 +1,168 @@
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
enable subgroups;
|
||||
#endif
|
||||
|
||||
struct Params {
|
||||
offset_s: u32,
|
||||
offset_x: u32,
|
||||
offset_dt: u32,
|
||||
offset_A: u32,
|
||||
offset_B: u32,
|
||||
offset_C: u32,
|
||||
offset_ids: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
stride_s1: u32,
|
||||
stride_s2: u32,
|
||||
stride_s3: u32,
|
||||
|
||||
stride_x1: u32,
|
||||
stride_x2: u32,
|
||||
stride_x3: u32,
|
||||
|
||||
stride_dt1: u32,
|
||||
stride_dt2: u32,
|
||||
|
||||
a_ne0: u32,
|
||||
stride_A1: u32,
|
||||
|
||||
stride_B1: u32,
|
||||
stride_B2: u32,
|
||||
stride_B3: u32,
|
||||
|
||||
stride_C1: u32,
|
||||
stride_C2: u32,
|
||||
stride_C3: u32,
|
||||
|
||||
d_state: u32,
|
||||
d_inner: u32,
|
||||
n_head: u32,
|
||||
n_group: u32,
|
||||
n_seq_tokens: u32,
|
||||
n_seqs: u32,
|
||||
|
||||
y_elems: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s_in: array<f32>;
|
||||
@group(0) @binding(1) var<storage, read_write> x: array<f32>;
|
||||
@group(0) @binding(2) var<storage, read_write> dt: array<f32>;
|
||||
@group(0) @binding(3) var<storage, read_write> A: array<f32>;
|
||||
@group(0) @binding(4) var<storage, read_write> B: array<f32>;
|
||||
@group(0) @binding(5) var<storage, read_write> C: array<f32>;
|
||||
@group(0) @binding(6) var<storage, read_write> ids: array<i32>;
|
||||
@group(0) @binding(7) var<storage, read_write> dst: array<f32>;
|
||||
@group(0) @binding(8) var<uniform> params: Params;
|
||||
|
||||
var<workgroup> shared_x_dt: array<f32, TOKENS_PER_TILE>;
|
||||
var<workgroup> shared_dtsp: array<f32, TOKENS_PER_TILE>;
|
||||
var<workgroup> shared_reduce: array<f32, TOKENS_PER_TILE * WG_SIZE>;
|
||||
|
||||
fn reduce_base(token_in_tile: u32) -> u32 {
|
||||
return token_in_tile * WG_SIZE;
|
||||
}
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
, @builtin(subgroup_id) subgroup_id: u32,
|
||||
@builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
|
||||
@builtin(num_subgroups) num_subgroups: u32
|
||||
#endif
|
||||
) {
|
||||
let tid = local_id.x;
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
|
||||
let i1 = wg_linear % params.d_inner;
|
||||
let head_seq = wg_linear / params.d_inner;
|
||||
let ir = head_seq % params.n_head;
|
||||
let i3 = head_seq / params.n_head;
|
||||
|
||||
let state_slot = u32(ids[params.offset_ids + i3]);
|
||||
let g = ir / (params.n_head / params.n_group);
|
||||
|
||||
let s_idx = params.offset_s + tid + i1 * params.stride_s1 + ir * params.stride_s2 + state_slot * params.stride_s3;
|
||||
var s_prev = s_in[s_idx];
|
||||
|
||||
let A0 = A[params.offset_A + (tid % params.a_ne0) + ir * params.stride_A1];
|
||||
|
||||
for (var token_base = 0u; token_base < params.n_seq_tokens; token_base += TOKENS_PER_TILE) {
|
||||
if (tid < TOKENS_PER_TILE) {
|
||||
let token = token_base + tid;
|
||||
if (token < params.n_seq_tokens) {
|
||||
let x_idx = params.offset_x + i1 + ir * params.stride_x1 + token * params.stride_x2 + i3 * params.stride_x3;
|
||||
let dt_idx = params.offset_dt + ir + token * params.stride_dt1 + i3 * params.stride_dt2;
|
||||
let dt0 = dt[dt_idx];
|
||||
let dtsp = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0);
|
||||
shared_dtsp[tid] = dtsp;
|
||||
shared_x_dt[tid] = x[x_idx] * dtsp;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
for (var token_in_tile = 0u; token_in_tile < TOKENS_PER_TILE; token_in_tile++) {
|
||||
let token = token_base + token_in_tile;
|
||||
if (token >= params.n_seq_tokens) {
|
||||
break;
|
||||
}
|
||||
|
||||
let x_dt = shared_x_dt[token_in_tile];
|
||||
let dA = exp(shared_dtsp[token_in_tile] * A0);
|
||||
let reduce_idx = reduce_base(token_in_tile) + tid;
|
||||
|
||||
let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3;
|
||||
let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3;
|
||||
let s = s_prev * dA + B[b_idx] * x_dt;
|
||||
s_prev = s;
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
let subgroup_partial = subgroupAdd(s * C[c_idx]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial;
|
||||
}
|
||||
#else
|
||||
shared_reduce[reduce_idx] = s * C[c_idx];
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
if (tid == 0u) {
|
||||
var sum = 0.0;
|
||||
for (var sg = 0u; sg < num_subgroups; sg++) {
|
||||
sum += shared_reduce[reduce_base(token_in_tile) + sg];
|
||||
}
|
||||
let y_idx =
|
||||
params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) +
|
||||
i3 * (params.n_seq_tokens * params.n_head * params.d_inner);
|
||||
dst[y_idx] = sum;
|
||||
}
|
||||
#else
|
||||
for (var stride = WG_SIZE / 2u; stride > 0u; stride >>= 1u) {
|
||||
if (tid < stride) {
|
||||
shared_reduce[reduce_idx] += shared_reduce[reduce_idx + stride];
|
||||
}
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
if (tid == 0u) {
|
||||
let y_idx =
|
||||
params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) +
|
||||
i3 * (params.n_seq_tokens * params.n_head * params.d_inner);
|
||||
dst[y_idx] = shared_reduce[reduce_base(token_in_tile)];
|
||||
}
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
}
|
||||
|
||||
let state_idx =
|
||||
params.offset_dst + params.y_elems + tid + i1 * params.d_state + ir * (params.d_state * params.d_inner) +
|
||||
i3 * (params.d_state * params.d_inner * params.n_head);
|
||||
dst[state_idx] = s_prev;
|
||||
}
|
||||
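For readers cross-checking ssm_scan.wgsl above, here is a small illustrative Python restatement of the recurrence one workgroup evaluates for a single (inner channel, head, sequence) triple; names follow the shader, but this sketch is an assumption-level aid, not the backend's reference implementation.

```python
import math

def ssm_scan_channel(s, A, x, dt, B, C):
    # s: state vector of length d_state (updated in place and returned)
    # A: per-state decay parameters, x/dt: per-token scalars,
    # B, C: per-token vectors of length d_state.
    ys = []
    for t in range(len(x)):
        # softplus(dt) with the same overflow shortcut as the shader (dt > 20 -> identity)
        dtsp = dt[t] if dt[t] > 20.0 else math.log(1.0 + math.exp(dt[t]))
        x_dt = x[t] * dtsp
        y = 0.0
        for n in range(len(s)):
            s[n] = s[n] * math.exp(dtsp * A[n]) + B[t][n] * x_dt  # state update
            y += s[n] * C[t][n]                                   # reduction over d_state
        ys.append(y)
    return ys, s  # per-token outputs and the final state written back to dst
```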
@@ -7656,7 +7656,7 @@ size_t ggml_quantize_chunk(
|
||||
int64_t nrows,
|
||||
int64_t n_per_row,
|
||||
const float * imatrix) {
|
||||
const int64_t n = (int64_t) nrows * n_per_row;
|
||||
const int64_t n = nrows * n_per_row;
|
||||
|
||||
if (ggml_quantize_requires_imatrix(type)) {
|
||||
GGML_ASSERT(imatrix != NULL);
|
||||
@@ -7673,21 +7673,21 @@ size_t ggml_quantize_chunk(
|
||||
size_t result = 0;
|
||||
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q1_0: result = quantize_q1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q1_0: result = quantize_q1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_0: result = quantize_q4_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_1: result = quantize_q4_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q5_0: result = quantize_q5_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q5_1: result = quantize_q5_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q8_0: result = quantize_q8_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_MXFP4: result = quantize_mxfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_NVFP4: result = quantize_nvfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q2_K: result = quantize_q2_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q3_K: result = quantize_q3_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_K: result = quantize_q4_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q5_K: result = quantize_q5_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q6_K: result = quantize_q6_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_TQ1_0: result = quantize_tq1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_TQ2_0: result = quantize_tq2_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
@@ -7752,9 +7752,9 @@ struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) {
|
||||
}
|
||||
|
||||
bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) {
|
||||
if (p0->n_threads != p1->n_threads ) return false;
|
||||
if (p0->prio != p1->prio ) return false;
|
||||
if (p0->poll != p1->poll ) return false;
|
||||
if (p0->strict_cpu != p1->strict_cpu ) return false;
|
||||
if (p0->n_threads != p1->n_threads ) return false;
|
||||
if (p0->prio != p1->prio ) return false;
|
||||
if (p0->poll != p1->poll ) return false;
|
||||
if (p0->strict_cpu != p1->strict_cpu ) return false;
|
||||
return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
|
||||
}
|
||||
|
||||
@@ -3,8 +3,12 @@
Test structured output capability via chat completions endpoint.

Each test case contains:
- response_format: OpenAI-compatible response_format specification
(json_schema only — llama.cpp does not support json_object)
- response_format: OpenAI-compatible response_format specification.
Both "json_schema" and "json_object" are accepted; with
"json_object" a schema can be supplied via extra_body.
- extra_body (optional): dict of extra top-level request fields merged into
the request payload (mirrors the OpenAI SDK's extra_body
feature; llama.cpp reads a top-level "json_schema" here).
- messages: initial conversation messages
- tools (optional): tool definitions (for mixed tool + structured tests)
- mock_tool_responses (optional): dict mapping tool_name -> callable(arguments) -> str (JSON)
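For orientation, a minimal client-side sketch of the json_object + extra_body path described above, written against the OpenAI Python SDK and a local llama.cpp server; the base URL, model name, and schema below are placeholder assumptions rather than values used by this test suite.

```python
from openai import OpenAI

# Placeholder endpoint; the api_key is a dummy value, a local server typically does not require one.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key")

resp = client.chat.completions.create(
    model="local-model",  # placeholder model name
    messages=[{"role": "user", "content": "Product: Wireless Headphones, ID: 101, In Stock: Yes"}],
    response_format={"type": "json_object"},
    # extra_body fields are merged into the top-level request body,
    # so the server sees a top-level "json_schema" field (illustrative schema).
    extra_body={"json_schema": {"type": "object", "title": "Product"}},
)
print(resp.choices[0].message.content)
```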
@@ -81,11 +85,14 @@ def print_info(msg):
_print(f"{DIM}{msg}{RESET}")


def print_schema_note(label, rf):
def print_schema_note(label, rf, extra_body=None):
kind = rf.get("type", "?")
name = ""
if kind == "json_schema":
name = rf.get("json_schema", {}).get("name", "")
elif kind == "json_object" and extra_body and "json_schema" in extra_body:
extra_schema = extra_body["json_schema"] or {}
name = extra_schema.get("title") or "extra_body.json_schema"
_print(f"{DIM}{MAGENTA} ⟐ response_format [{label}]: {kind}"
f"{(' / ' + name) if name else ''}{RESET}")

@@ -95,17 +102,20 @@ def print_schema_note(label, rf):
# ---------------------------------------------------------------------------


def chat_completion(url, messages, tools=None, response_format=None, stream=False):
def chat_completion(url, messages, tools=None, response_format=None, stream=False,
extra_body=None):
payload = {
"messages": messages,
"stream": stream,
"max_tokens": 4096,
"max_tokens": 8192,
}
if tools:
payload["tools"] = tools
payload["tool_choice"] = "auto"
if response_format is not None:
payload["response_format"] = response_format
if extra_body:
payload.update(extra_body)

try:
response = requests.post(url, json=payload, stream=stream)
@@ -180,7 +190,7 @@ def chat_completion(url, messages, tools=None, response_format=None, stream=Fals

def run_tool_loop(
url, messages, tools, mock_tool_responses, stream, response_format=None,
max_turns=6,
extra_body=None, max_turns=6,
):
"""
Drive the tool-call loop. If response_format is provided it is applied to
@@ -191,7 +201,8 @@ def run_tool_loop(

for _ in range(max_turns):
result = chat_completion(
url, msgs, tools=tools, response_format=response_format, stream=stream
url, msgs, tools=tools, response_format=response_format, stream=stream,
extra_body=extra_body,
)
if result is None:
return all_tool_calls, msgs, None
@@ -274,7 +285,8 @@ def run_test(url, test_case, stream):
print_header(f"{name} [{mode}] ({apply_stage})")

response_format = test_case["response_format"]
print_schema_note(apply_stage, response_format)
extra_body = test_case.get("extra_body")
print_schema_note(apply_stage, response_format, extra_body)

tools = test_case.get("tools")
|
||||
mocks = test_case.get("mock_tool_responses") or {}
|
||||
@@ -290,6 +302,7 @@ def run_test(url, test_case, stream):
|
||||
mock_tool_responses=mocks,
|
||||
stream=stream,
|
||||
response_format=response_format,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
elif apply_stage == "after_tools":
|
||||
# Phase 1: plain tool loop, no response_format applied yet.
|
||||
@@ -314,7 +327,8 @@ def run_test(url, test_case, stream):
# model focuses on producing the schema-constrained answer.
_print(f"\n{DIM}{MAGENTA} ⟐ follow-up turn with response_format applied{RESET}")
result = chat_completion(
url, msgs, tools=None, response_format=response_format, stream=stream
url, msgs, tools=None, response_format=response_format, stream=stream,
extra_body=extra_body,
)
final_content = result["content"] if result else None
else:
@@ -481,6 +495,51 @@ def _validate_sentiment(parsed):
return True, f"sentiment={parsed['sentiment']} conf={conf} kws={kws}"


# ---- Test: json_object + extra_body.json_schema (always) ----
#
# Exercises the llama.cpp-specific path where the OpenAI SDK would send
# response_format={"type": "json_object"} and tunnel the schema through
# extra_body.json_schema (which becomes a top-level "json_schema" field on
# the request body).

_PRODUCT_JSON_OBJECT_SCHEMA = {
"$schema": "https://json-schema.org/draft/2020-12/schema",
"$id": "https://example.com/product.schema.json",
"title": "Product",
"description": "A product in the catalog",
"type": "object",
}

PRODUCT_JSON_OBJECT_TEST_CASE = {
"name": "json_object response_format with extra_body json_schema",
"response_format": {"type": "json_object"},
"extra_body": {"json_schema": _PRODUCT_JSON_OBJECT_SCHEMA},
"apply_stage": "always",
"messages": [
{
"role": "system",
"content": (
"Extract structured data from the provided text according to the "
"JSON schema. Return only valid JSON matching the schema exactly."
),
},
{
"role": "user",
"content": "Product: Wireless Headphones, ID: 101, In Stock: Yes",
},
],
"validate": lambda parsed, tcs, raw: _validate_product_json_object(parsed),
}


def _validate_product_json_object(parsed):
if not isinstance(parsed, dict):
return False, f"expected JSON object, got {type(parsed).__name__}: {parsed!r}"
if not parsed:
return False, f"expected non-empty object, got {parsed!r}"
return True, f"product object with {len(parsed)} field(s): {sorted(parsed.keys())}"
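
As a rough illustration (again not part of the test file), the request body that the updated chat_completion() helper assembles for this test case would look roughly like the sketch below; because extra_body is merged with payload.update(), the schema travels as a top-level "json_schema" field alongside response_format.

# Approximate payload built by chat_completion() for PRODUCT_JSON_OBJECT_TEST_CASE:
payload = {
    "messages": PRODUCT_JSON_OBJECT_TEST_CASE["messages"],
    "stream": False,
    "max_tokens": 8192,
    "response_format": {"type": "json_object"},
    "json_schema": _PRODUCT_JSON_OBJECT_SCHEMA,  # merged in from extra_body
}
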

# ---- Test 3: Nested recipe schema (always) ----

_RECIPE_SCHEMA = {
@@ -915,6 +974,7 @@ def _validate_country_report(parsed, tcs):
ALL_TEST_CASES = [
BOOK_TEST_CASE,
SENTIMENT_TEST_CASE,
PRODUCT_JSON_OBJECT_TEST_CASE,
RECIPE_TEST_CASE,
SHOP_COMPARISON_TEST_CASE,
COUNTRY_REPORT_TEST_CASE,

@@ -1283,7 +1283,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
llama_model_quantize_params llama_model_quantize_default_params() {
llama_model_quantize_params result = {
/*.nthread =*/ 0,
/*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
/*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q8_0,
/*.output_tensor_type =*/ GGML_TYPE_COUNT,
/*.token_embedding_type =*/ GGML_TYPE_COUNT,
/*.allow_requantize =*/ false,

@@ -1331,7 +1331,7 @@ static void test_nemotron_reasoning_detection(testing & t) {

// Check reasoning markers
t.assert_equal("reasoning_start should be '<think>\\n'", "<think>\n", analysis.reasoning.start);
t.assert_equal("reasoning_end should be '</think>'", "</think>", analysis.reasoning.end);
t.assert_equal("reasoning_end should be '\\n</think>\\n'", "\n</think>\n", analysis.reasoning.end);

// Check reasoning mode detection
// Nemotron uses tag-based reasoning; prefill handles the template's forced markers

@@ -1642,22 +1642,16 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Qwen3.5 (basically same as Nemotron, but keeping separate tests just in case)
auto tst = peg_tester("models/templates/Qwen3.5-4B.jinja", detailed_debug);

tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?")
tst.test("I'm\nthinking\n</think>\n\nHello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.expect(message_assist_thoughts)
.run();

tst.test("I'm\nthinking\n</think>\nHello, world!\nWhat's up?")
tst.test("I'm\nthinking\n</think>\n\nHello, world!\nWhat's up?")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_NONE)
.expect_content("<think>\nI'm\nthinking\n</think>\nHello, world!\nWhat's up?")
.run();

tst.test("I'm\nthinking\n</think>\nHello, world!\nWhat's up?")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.expect(message_assist_thoughts)
.expect_content("<think>\nI'm\nthinking\n</think>\n\nHello, world!\nWhat's up?")
.run();

tst.test(
@@ -1673,7 +1667,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.run();

tst.test(
"I'm\nthinking\n</think>\n"
"I'm\nthinking\n</think>\n\n"
"<tool_call>\n"
"<function=special_function>\n"
"<parameter=arg1>\n1\n</parameter>\n"
@@ -1731,7 +1725,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {

tst.test(
"I need to output the invoice details in JSON\n"
"</think>\n"
"</think>\n\n"
R"({"amount": 123.45, "date": "2025-12-03"})")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
@@ -1751,7 +1745,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
"hello()\n"
"</parameter>\n"
"</function>\n"
"</tool_call></think>\n"
"</tool_call>\n</think>\n\n"
"<tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
@@ -1994,7 +1988,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
"hello()\n"
"</parameter>\n"
"</function>\n"
"</tool_call></think>\n"
"</tool_call>\n</think>\n"
"<tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
@@ -3463,7 +3457,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.run();

// Tool call with reasoning (enable_thinking=true)
tst.test("I'm\nthinking</think><tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}</tool_call>")
tst.test("I'm\nthinking\n</think>\n\n<tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}</tool_call>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
@@ -3487,7 +3481,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.run();

// Tool call with reasoning and content
tst.test("I need to call a function</think>"
tst.test("I need to call a function\n</think>\n\n"
"Let me check the time.<tool_call>\n{\"name\": \"get_time\", \"arguments\": {\"city\": \"XYZCITY\"}}</tool_call>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
@@ -3514,7 +3508,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {

// fake tool call marker in reasoning
tst.test(
"Let me think about <tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 2}}</tool_call> hmm</think>"
"Let me think about <tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 2}}</tool_call> hmm\n</think>\n\n"
"<tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}</tool_call>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
@@ -3542,11 +3536,11 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Format: <minimax:tool_call><invoke name="func"><parameter name="key">value</parameter></invoke></minimax:tool_call>
{
auto tst = peg_tester("models/templates/MiniMax-M2.jinja", detailed_debug);
tst.test("</think>Hello, world!\nWhat's up?").enable_thinking(true).reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(message_assist).run();
tst.test("\n</think>\n\nHello, world!\nWhat's up?").enable_thinking(true).reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(message_assist).run();

tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?").enable_thinking(true).reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(message_assist_thoughts).run();
tst.test("I'm\nthinking\n</think>\n\nHello, world!\nWhat's up?").enable_thinking(true).reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(message_assist_thoughts).run();

tst.test("Let's call a tool:</think><minimax:tool_call>\n<invoke name=\"empty_args\">\n</invoke>\n</minimax:tool_call>").
tst.test("Let's call a tool:\n</think>\n\n<minimax:tool_call>\n<invoke name=\"empty_args\">\n</invoke>\n</minimax:tool_call>").
enable_thinking(true).
reasoning_format(COMMON_REASONING_FORMAT_AUTO).
tools({ empty_args_tool }).
@@ -3554,7 +3548,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
run();

tst.test(
"</think><minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
"\n</think>\n\n<minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
"name=\"arg1\">1</parameter>\n</invoke>\n</minimax:tool_call>")
.tools({ special_function_tool })
.expect(message_assist_call)
@@ -3714,7 +3708,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.enable_thinking(false)
.expect(message_assist)
.run();
tst.test("I'm\nthinking</think>\n\nHello, world!\nWhat's up?")
tst.test("I'm\nthinking\n</think>\n\nHello, world!\nWhat's up?")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.expect(message_assist_thoughts)
@@ -3729,7 +3723,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.tools({ special_function_tool })
.expect(message_assist_call_content)
.run();
tst.test("I'm\nthinking</think>\n\n<tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n</tool_call>")
tst.test("I'm\nthinking\n</think>\n\n<tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n</tool_call>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ special_function_tool })
@@ -4006,7 +4000,8 @@ static void test_template_output_peg_parsers(bool detailed_debug) {

{
auto tst = peg_tester("models/templates/StepFun3.5-Flash.jinja", detailed_debug);
tst.test("I was thinking</think>\nNow I'm not.").

tst.test("I was thinking\n</think>\nNow I'm not.").
enable_thinking(true).
reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK).
expect_reasoning("I was thinking").

@@ -947,7 +947,9 @@ json oaicompat_chat_params_parse(
json response_format = json_value(body, "response_format", json::object());
std::string response_type = json_value(response_format, "type", std::string());
if (response_type == "json_object") {
json_schema = json_value(response_format, "schema", json::object());
if (response_format.contains("schema") || json_schema.empty()) {
json_schema = json_value(response_format, "schema", json::object());
}
} else if (response_type == "json_schema") {
auto schema_wrapper = json_value(response_format, "json_schema", json::object());
json_schema = json_value(schema_wrapper, "schema", json::object());
json_schema = json_value(schema_wrapper, "schema", json::object());
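
With this change, a "json_object" request should resolve to a schema whether the schema is embedded inside response_format or supplied as the top-level "json_schema" field that llama.cpp accepts (and which, judging by the guard on json_schema.empty(), is presumably read earlier in this function). The sketch below uses illustrative variable names and placeholder schemas to show the two request shapes this parsing is meant to accept; the new guard avoids clobbering an already-populated json_schema when response_format carries no "schema" key.

# Sketch: two request bodies that should now both yield a schema for "json_object".
body_with_inline_schema = {
    "response_format": {"type": "json_object", "schema": {"type": "object"}},
}
body_with_top_level_schema = {
    "response_format": {"type": "json_object"},
    "json_schema": {"type": "object"},  # top-level field read by llama.cpp
}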