Compare commits

..

80 Commits

Author SHA1 Message Date
Georgi Gerganov
7a3c178d78 speculative : adapt to new llama API
ggml-ci
2025-03-18 22:05:44 +02:00
Xuan Son Nguyen
dc4bb64290 Merge branch 'master' into xsn/private_batch_api 2025-03-18 15:45:22 +01:00
Georgi Gerganov
8551c44d84 context : always use non-causal attention for encoder graphs (#12447)
* context : always use non-causal attention for encoder graphs

ggml-ci

* context : move the change to llama_context::encode()

ggml-ci
2025-03-18 13:05:49 +02:00
Łukasz Ślusarczyk
35cae5ba05 SYCL: using graphs is configurable by environment variable and compile option (#12371)
* alberto changes

* enable sycl graphs by env variable

* fixed compilation warnings in ggml-sycl.cpp

* renamed graph variables

* fix markdown in docs/backend/SYCL.md

Co-authored-by: Romain Biessy <romain.biessy@codeplay.com>

* fix markdown in docs/backend/SYCL.md again

* compiling graphs by default, renamed graph_enable to graph_disable

---------

Co-authored-by: Romain Biessy <romain.biessy@codeplay.com>
2025-03-18 11:16:31 +01:00
Georgi Gerganov
810e0af3f5 server : fix warmup draft cache type (#12446)
ggml-ci
2025-03-18 12:05:42 +02:00
Prajwal B Mehendarkar
eba92d64c3 cmake : fix PowerPC build (#12241)
Closes #12240
2025-03-18 11:37:33 +02:00
fj-y-saito
d9a14523bb ggml : add SVE support for q6_K_q8_K (#12361) 2025-03-18 10:14:39 +02:00
0cc4m
fd123cfead Vulkan: Default to 1GB allocations instead of 4GB to avoid fragmentation and driver issues (#12434) 2025-03-18 07:21:40 +01:00
Łukasz Ślusarczyk
a53f7f7b88 fixed compilation warnings in ggml-sycl (#12424)
Some checks are pending
Python check requirements.txt / check-requirements (push) Waiting to run
flake8 Lint / Lint (push) Waiting to run
Python Type-Check / pyright type-check (push) Waiting to run
2025-03-18 08:51:25 +08:00
Molly Sophia
7dfad387e3 llama: Add support for RWKV v7 architecture (#12412)
* ggml: Add op l2_norm

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* ggml: Add op rwkv_wkv7

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: Add support for RWKV7 and ARWKV7 models

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: fix inference with RWKV6Qwen2

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: add more (a)rwkv7 variants in size

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Apply code-format changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* fix MUSA build

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: fix shape error with rwkv using llama-parallel

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
2025-03-18 07:27:50 +08:00
Sigbjørn Skjæret
60c902926c docs : bring llama-cli conversation/template docs up-to-date (#12426) 2025-03-17 21:14:32 +01:00
Gaurav Garg
b1b132efcb cuda : enable CUDA Graph on CUDA Toolkit < 12.x (#12394)
* Enable CUDA Graph on CTK < 12.x

`cudaGraphExecUpdate` API was changed on 12.x. For this reason CUDA graph support was disabled on older CUDA toolkit. This change enables CUDA support in CTK version < 12.x by using older API if CTK < 12.x.

* Fix compilation errors with MUSA

* Disable CUDA Graph for MUSA
2025-03-17 20:25:13 +02:00
Guus Waals
01e8f2138b ggml-vulkan: remove unused find_program(glslc) (#12416)
It's already found by FindVulkan.cmake in the parent CMakeLists
2025-03-17 13:35:43 -03:00
Jeff Bolz
484a8ab513 vulkan: Add N/2 and N/4 optimized paths in coopmat2 shader (#12312) 2025-03-17 09:26:18 -05:00
Daniele
cf2270e4d3 vulkan: subgroup size tuning (#12087)
* vulkan: subgroup size test

* Vulkan: Add device architecture enum and logic to recognize AMD generations

* vulkan: use new architecture logic to specify subgroup size

* Initial vulkan subgroup size tuning for RDNA3

* vulkan: commonize RDNA subgroup tuning

* vulkan: override subgroup size if required_subgroup_size = 0

* vulkan: disable warp 32 for RDNA3

* vulkan: fine tuned RDNA1 subgroup sizes

* vulkan: adjusted subgroup size map

* vulkan: fixed RDNA2 subgroup map

---------

Co-authored-by: 0cc4m <picard12@live.de>
2025-03-17 12:42:33 +01:00
Xuan-Son Nguyen
eab5606d7b Apply suggestions from code review 2025-03-17 12:17:14 +01:00
Xuan-Son Nguyen
de788e071b Update examples/tts/tts.cpp
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-03-17 12:05:23 +01:00
Jeff Bolz
f07690c930 vulkan: use fp32 in coopmat2 q4_k dequant function (#12309) 2025-03-17 10:43:35 +01:00
Jeff Bolz
891c63956d vulkan: Pad N dimension of B matrix for coopmat2 perf, to avoid bounds checking (#12273)
* vulkan: Pad N dimension of B matrix for coopmat2 perf, to avoid bounds checking
2025-03-17 10:41:59 +01:00
Jeff Bolz
2f21123c1d vulkan: Adjust coopmat2 tile sizes and selection heuristic (#12258) 2025-03-17 10:35:00 +01:00
Christian Kastner
374101fd74 cmake : enable building llama.cpp using system libggml (#12321)
* cmake: Factor out compiler flag function from ggml

llama.cpps's build requires it, too, and we may want to make use of it
without add_subdirectory(ggml).

* cmake: Enable building against system ggml

This facilitates package maintenance for Linux distributions, where the
libggml library most likely will be shipped as an individual package
upon which a llama.cpp package depends.
2025-03-17 11:05:23 +02:00
Akarshan Biswas
b3c9a65673 SYCL: set extras only on GGML_TYPE_Q4_0 (#12366)
* SYCL: set extras only on GGML_TYPE_Q4_0

* release tensor_extras in reset buffer interface
2025-03-17 09:45:12 +08:00
Sigbjørn Skjæret
8ba95dca20 llama : fix OLMo-2-0325-32B-Instruct K-norm size (#12400) 2025-03-16 19:46:36 +02:00
Georgi Gerganov
dc079cfdff context : fix init of n_outputs (#12397)
ggml-ci
2025-03-16 19:29:36 +02:00
Daniel Bevenius
7b61bcc87c ci : add --symlinks to xcframework zip command (#12409)
This commit adds the --symlinks option to the zip command used to create
the xcframework zip file. This is necessary to create symlinks in the
zip file. Without this option,  the Versions symlink is stored as a
regular directory entry in the zip file, rather than as a symlink in the
zip which causes the followig error in xcode:
```console
Couldn't resolve framework symlink for '/Users/danbev/work/ai/llama.cpp/tmp_1/build-apple/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current': readlink(/Users/danbev/work/ai/llama.cpp/tmp_1/build-apple/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current): Invalid argument (22)
```

Refs: https://github.com/ggml-org/llama.cpp/pull/11996#issuecomment-2727026377
2025-03-16 18:22:05 +01:00
marcoStocchi
f4c3dd5daa llama-tts : add '-o' option (#12398)
* added -o option to specify an output file name

* llama-tts returns ENOENT in case of file write error

note : PR #12042 is closed as superseded with this one.
2025-03-15 17:23:11 +01:00
aubreyli
3d35d87b41 SYCL: Delete redundant plus sign and space (#12391) 2025-03-15 15:49:03 +01:00
fairydreaming
b19bd064c0 SYCL : support non-contiguous tensors in binary ops (add, sub, etc) (#12399)
* sycl : support non-contiguous tensors in binary ops

* sycl : silence unused variable warning

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
2025-03-15 22:19:30 +08:00
Chenguang Li
92a391327e [CANN]MUL_MAT optimization (#12382) 2025-03-15 09:31:08 +08:00
Xuan Son Nguyen
624a683c6f fix compile 2025-03-14 22:30:29 +01:00
Xuan Son Nguyen
116b9a1662 rename to init_from_text 2025-03-14 22:17:07 +01:00
Eric Curtin
9f2250ba72 Add CLI arg to llama-run to adjust the number of threads used (#12370)
We default to 4, sometimes we want to manually adjust this

Signed-off-by: Eric Curtin <ecurtin@redhat.com>
2025-03-14 16:41:20 +00:00
Xuan Son Nguyen
eaffba0f2e llama_batch_ext_ptr::from_text/embd 2025-03-14 17:12:03 +01:00
Sigbjørn Skjæret
774973b8f3 main : add -sysf / --system-prompt-file (#12249) (#12250)
* add system_prompt_file

* add -sysf / --system-prompt-file

* remove system_prompt_file
2025-03-14 16:57:05 +01:00
fairydreaming
8fcb563613 Load all MoE experts during warmup (#11571)
* llama : introduce llama_set_warmup() API call that controls warmup mode; use all MoE experts during warmup

* common : use new API to enable warmup mode during model warmup

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
2025-03-14 13:47:05 +01:00
Xuan Son Nguyen
8e7714fa77 fix compile 2025-03-14 11:28:15 +01:00
Xuan Son Nguyen
a363251fac qwen2vl: use llama_batch_ext_set_pos 2025-03-14 11:25:36 +01:00
Victor
add2a3aa5a server: fix "--grammar-file" parameter (#12285) 2025-03-14 11:21:17 +01:00
Xuan Son Nguyen
ba79369615 fix llama_batch_ext_init_from_embd 2025-03-14 11:17:22 +01:00
Xuan Son Nguyen
07d84fa3c2 fix missing n_past in various places
this is actually a revert of cda0e4b648
2025-03-14 10:47:08 +01:00
Xuan Son Nguyen
32940369d3 fix gemma3-cli 2025-03-14 10:33:28 +01:00
Xuan Son Nguyen
5e6a6d4e1c fix llama-run n_past 2025-03-14 10:32:43 +01:00
Georgi Gerganov
c522ce4143 graph : simplify attn input build for unified KV cache (#12381)
ggml-ci
2025-03-14 10:47:44 +02:00
Georgi Gerganov
081bee8c64 hparams : add SWA rope parameters (#12374)
ggml-ci
2025-03-14 09:03:24 +02:00
Xuan Son Nguyen
bfdddbc150 bring back mistakenly deleted llama_batch_init/free 2025-03-14 00:22:28 +01:00
Xuan Son Nguyen
54566ad95d correct comment 2025-03-14 00:21:06 +01:00
Xuan Son Nguyen
04f8641815 rm redundant llama_batch_ext_set_output_last 2025-03-13 23:14:16 +01:00
Xuan Son Nguyen
c3dd79007b fix llama_batch_ext_init_from_text 2025-03-13 23:09:27 +01:00
Xuan Son Nguyen
65f0184517 compile ok 2025-03-13 22:56:35 +01:00
Xuan Son Nguyen
9fb2d81eab fix common_batch missing seq_id 2025-03-13 22:38:04 +01:00
Xuan Son Nguyen
47086fa82d apply to the rest 2025-03-13 22:36:27 +01:00
Georgi Gerganov
84d5475541 llama : fix Gemma3 SWA KV cache shift (#12373)
* llama : fix Gemma3 SWA KV cache shift

ggml-ci

* hparams : add comment [no ci]
2025-03-13 19:08:07 +02:00
Xuan Son Nguyen
4aabf4e8f4 return output ID from llama_batch_ext_add/set 2025-03-13 17:47:07 +01:00
Xuan Son Nguyen
86973cb14a fix merge errors 2025-03-13 17:32:36 +01:00
Xuan Son Nguyen
17f954c8e2 Merge branch 'master' into xsn/private_batch_api 2025-03-13 15:55:18 +01:00
Xuan-Son Nguyen
be7c303410 arg : no n_predict = -2 for examples except for main and infill (#12364) 2025-03-13 12:34:54 +01:00
Georgi Gerganov
e0dbec0bc6 llama : refactor llama_context, llama_kv_cache, llm_build_context (#12181)
Some checks failed
flake8 Lint / Lint (push) Has been cancelled
Python Type-Check / pyright type-check (push) Has been cancelled
* llama : refactor llama_context, llama_kv_cache, llm_build_context

ggml-ci

* graph : don't mutate the KV cache during defrag

ggml-ci

* context : reduce virtuals + remove test function

ggml-ci

* context : move interface implementation to source file + factory

ggml-ci

* graph : move KV cache build functions to llama_context impl

ggml-ci

* graph : remove model reference from build_pooling

ggml-ci

* graph : remove llama_model reference

ggml-ci

* kv_cache : provide rope factors

ggml-ci

* graph : rework inputs to use only unique_ptr, remove attn input abstraction

ggml-ci

* context : remove llama_context_i abstraction

ggml-ci

* context : clean-up

ggml-ci

* graph : clean-up

ggml-ci

* llama : remove redundant keywords (struct, enum)

ggml-ci

* model : adapt gemma3

ggml-ci

* graph : restore same attention ops as on master

ggml-ci

* llama : remove TODO + fix indent

ggml-ci
2025-03-13 12:35:44 +02:00
Ishaan Gandhi
2048b5913d server : fix crash when using verbose output with input tokens that are not in printable range (#12178) (#12338)
* Fix DOS index bug

* Remove new APIs

* remove extra line

* Remove from API

* Add extra newline

* Update examples/server/server.cpp

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2025-03-13 11:10:05 +01:00
Oscar Barenys
f08f4b3187 Update build.yml for Windows Vulkan builder to use Vulkan 1.4.304 SDK for VK_NV_cooperative_matrix2 support (#12301) 2025-03-12 20:06:58 +01:00
Daniel Bevenius
80a02aa858 llama.swiftui : fix xcframework dir in README [no ci] (#12353)
This commit fixes the path to the xcframework in the README file which I
had forgotten to change after renaming the build directory.
2025-03-12 13:45:32 +01:00
Alberto Cabrera Pérez
363f8c5d67 sycl : variable sg_size support for mmvq kernels (#12336)
Some checks failed
flake8 Lint / Lint (push) Waiting to run
Python Type-Check / pyright type-check (push) Waiting to run
Python check requirements.txt / check-requirements (push) Has been cancelled
2025-03-12 09:57:32 +00:00
uvos
34c961b181 CUDA/HIP: Fix fattn-vec-* when device warp size is not 32 (#12315)
When fattn-wmma was ported over to warp64 various bits that also touch fattn-vec where converted to
selectable warp size, however the fattn-vec kernels dont work with 64 wide warps for now, so we need
to avoid launching them with parameters for warp64
2025-03-12 10:14:11 +01:00
Xuan-Son Nguyen
7841fc723e llama : Add Gemma 3 support (+ experimental vision capability) (#12343)
* llama : Add Gemma 3 text-only support

* fix python coding style

* fix compile on ubuntu

* python: fix style

* fix ubuntu compile

* fix build on ubuntu (again)

* fix ubuntu build, finally

* clip : Experimental support for Gemma 3 vision (#12344)

* clip : Experimental support for Gemma 3 vision

* fix build

* PRId64
2025-03-12 09:30:24 +01:00
Jeff Bolz
bf69cfe62f vulkan: fix bug in coopmat1 mul_mat_id (#12316)
* tests: run mul_mat_id with a larger N

* vulkan: fix bug in coopmat1 mul_mat_id
2025-03-12 06:59:19 +01:00
uvos
10f2e81809 CUDA/HIP: refractor mmqv to unify the calculation of nwarps and rows per block between host and device code. (#12177)
refactor mmqv to unify the calculation of nwarps and rows per block between host and device code.

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2025-03-11 20:16:03 +01:00
jklincn
ba7654380a ggml-backend : fix backend search path (#12330)
* Fix backend search path

* replace .native() with '/'

* reverted .native()
2025-03-11 14:25:17 +01:00
BB-fat
6ab2e4765a metal : Cache the Metal library at the device context level (#12265) 2025-03-11 13:45:02 +02:00
Xuan Son Nguyen
46596caf6d apply various in places 2025-03-01 20:42:18 +01:00
Xuan Son Nguyen
1d6ba97789 remove token_info API 2025-03-01 16:21:16 +01:00
Xuan Son Nguyen
1170135dfb llama_batch_ext_add_text 2025-03-01 14:00:14 +01:00
Xuan Son Nguyen
40989f4116 correct llama_decode_ext 2025-03-01 14:00:05 +01:00
Xuan Son Nguyen
9e75c49d35 Merge branch 'master' into xsn/private_batch_api 2025-03-01 12:13:03 +01:00
Xuan Son Nguyen
f0ffd81130 adapt common 2025-03-01 12:12:52 +01:00
Xuan Son Nguyen
a1b1dea33b Merge branch 'master' into xsn/private_batch_api 2025-02-24 17:01:30 +01:00
Xuan Son Nguyen
4bf7ca3943 llama_decode_ext 2025-02-24 17:01:20 +01:00
Xuan Son Nguyen
aed4a8e980 fix server 2025-02-16 11:36:50 +01:00
Xuan Son Nguyen
85ef80cbe9 server : use llama_batch_ext 2025-02-16 00:06:48 +01:00
Xuan Son Nguyen
17d3658b5f move to llama_batch_ext 2025-02-16 00:02:53 +01:00
Xuan Son Nguyen
f2e59a8eb9 rework, targeting llama-server 2025-02-14 18:16:49 +01:00
Xuan Son Nguyen
4ed4fe75ed first proposal for private llama_batch 2025-02-14 00:48:12 +01:00
132 changed files with 20089 additions and 13634 deletions

View File

@@ -774,7 +774,7 @@ jobs:
env:
OPENBLAS_VERSION: 0.3.23
SDE_VERSION: 9.33.0-2024-01-07
VULKAN_VERSION: 1.3.261.1
VULKAN_VERSION: 1.4.304.1
strategy:
matrix:
@@ -1379,7 +1379,7 @@ jobs:
id: pack_artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: |
zip -r llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework
zip --symlinks -r llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework
- name: Upload artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}

View File

@@ -29,6 +29,8 @@ else()
set(LLAMA_STANDALONE OFF)
endif()
option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF)
if (EMSCRIPTEN)
set(BUILD_SHARED_LIBS_DEFAULT OFF)
@@ -145,7 +147,13 @@ endif()
# 3rd-party
#
if (NOT TARGET ggml)
if (LLAMA_USE_SYSTEM_GGML)
message(STATUS "Using system-provided libggml, skipping ggml build")
find_package(ggml REQUIRED)
add_library(ggml ALIAS ggml::ggml)
endif()
if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML)
add_subdirectory(ggml)
# ... otherwise assume ggml is added by a parent CMakeLists.txt
endif()

View File

@@ -1,3 +1,5 @@
include("ggml/cmake/common.cmake")
function(llama_add_compile_flags)
if (LLAMA_FATAL_WARNINGS)
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")

View File

@@ -764,7 +764,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_env("LLAMA_ARG_CTX_SIZE"));
add_opt(common_arg(
{"-n", "--predict", "--n-predict"}, "N",
string_format("number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict),
string_format(
ex == LLAMA_EXAMPLE_MAIN || ex == LLAMA_EXAMPLE_INFILL
? "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)"
: "number of tokens to predict (default: %d, -1 = infinity)",
params.n_predict),
[](common_params & params, int value) {
params.n_predict = value;
}
@@ -849,6 +853,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_excludes({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-sysf", "--system-prompt-file"}, "FNAME",
"a file containing the system prompt (default: none)",
[](common_params & params, const std::string & value) {
std::ifstream file(value);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
}
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.system_prompt));
if (!params.system_prompt.empty() && params.system_prompt.back() == '\n') {
params.system_prompt.pop_back();
}
}
).set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(common_arg(
{"--in-file"}, "FNAME",
"an input file (repeat to specify multiple files)",
@@ -1871,7 +1889,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.out_file = value;
}
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA}));
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS}));
add_opt(common_arg(
{"-ofreq", "--output-frequency"}, "N",
string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq),

View File

@@ -582,41 +582,6 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llam
return buf.str();
}
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
std::stringstream buf;
buf << "[ ";
bool first = true;
for (int i = 0; i < batch.n_tokens; ++i) {
if (!first) {
buf << ", ";
} else {
first = false;
}
auto detokenized = common_token_to_piece(ctx, batch.token[i]);
detokenized.erase(
std::remove_if(
detokenized.begin(),
detokenized.end(),
[](const unsigned char c) { return !std::isprint(c); }),
detokenized.end());
buf << "\n" << std::to_string(i)
<< ", token '" << detokenized << "'"
<< ", pos " << std::to_string(batch.pos[i])
<< ", n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ", seq_id " << std::to_string(batch.seq_id[i][0])
<< ", logits " << std::to_string(batch.logits[i]);
}
buf << " ]";
return buf.str();
}
void string_process_escapes(std::string & input) {
std::size_t input_len = input.length();
std::size_t output_idx = 0;
@@ -955,8 +920,8 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
params.ctx_shift = false;
}
@@ -1033,6 +998,8 @@ struct common_init_result common_init_from_params(common_params & params) {
if (params.warmup) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
llama_set_warmup(lctx, true);
std::vector<llama_token> tmp;
llama_token bos = llama_vocab_bos(vocab);
llama_token eos = llama_vocab_eos(vocab);
@@ -1049,7 +1016,8 @@ struct common_init_result common_init_from_params(common_params & params) {
}
if (llama_model_has_encoder(model)) {
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0, true));
llama_encode_ext(lctx, batch.get());
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
decoder_start_token_id = bos;
@@ -1058,11 +1026,13 @@ struct common_init_result common_init_from_params(common_params & params) {
tmp.push_back(decoder_start_token_id);
}
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true));
llama_decode_ext(lctx, batch.get());
}
llama_kv_cache_clear(lctx);
llama_kv_self_clear(lctx);
llama_synchronize(lctx);
llama_perf_context_reset(lctx);
llama_set_warmup(lctx, false);
}
iparams.model.reset(model);
@@ -1610,10 +1580,12 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons
// Batch utils
//
// DEPRECATED
void common_batch_clear(struct llama_batch & batch) {
batch.n_tokens = 0;
}
// DEPRECATED
void common_batch_add(
struct llama_batch & batch,
llama_token id,

View File

@@ -516,7 +516,6 @@ void string_process_escapes(std::string & input);
std::string string_from(bool value);
std::string string_from(const std::vector<int> & values);
std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
//
// Filesystem utils
@@ -570,8 +569,10 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap
// Batch utils
//
// DEPRECATED
void common_batch_clear(struct llama_batch & batch);
// DEPRECATED
void common_batch_add(
struct llama_batch & batch,
llama_token id,
@@ -579,6 +580,66 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
// this is meant to be temporary
struct common_batch {
llama_batch_ext_ptr batch;
struct batch_token {
llama_token token;
llama_seq_id seq_id; // only support single seq for now
bool logits;
};
std::vector<batch_token> tokens;
int n_outputs = 0;
common_batch() = default;
common_batch(int32_t n_tokens, int32_t n_seq_max) {
batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
tokens.reserve(n_tokens);
}
void clear() {
llama_batch_ext_clear(batch.get());
tokens.clear();
}
void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
tokens.push_back({token, seq_id, logits});
if (logits) {
n_outputs++;
}
}
void add_text_multi_seq(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
tokens.push_back({token, seq_ids[0], logits});
if (logits) {
n_outputs++;
}
}
void set_logits_last() {
if (!tokens.empty()) {
llama_batch_ext_set_output_last(batch.get());
tokens.back().logits = true;
}
}
int32_t get_n_tokens() const {
return (int32_t)tokens.size();
}
llama_batch_ext * get() {
return batch.get();
}
common_batch get_view(int32_t offset, int32_t n_tokens) {
common_batch view;
view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
view.tokens.reserve(n_tokens);
for (int32_t i = 0; i < n_tokens; i++) {
view.tokens.push_back(tokens[offset + i]);
if (tokens[offset + i].logits) {
view.n_outputs++;
}
}
return view;
}
};
//
// Token utils
//

View File

@@ -14,7 +14,7 @@ struct common_speculative {
struct llama_context * ctx;
struct common_sampler * smpl;
llama_batch batch;
llama_batch_ext_ptr batch;
llama_tokens prompt;
};
@@ -23,7 +23,7 @@ struct common_speculative * common_speculative_init(
auto * result = new common_speculative {
/* .ctx = */ ctx_dft,
/* .smpl = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
/* .batch = */ llama_batch_ext_ptr(llama_batch_ext_init(llama_n_batch(ctx_dft), 1)),
/* .prompt = */ {},
};
@@ -69,8 +69,6 @@ void common_speculative_free(struct common_speculative * spec) {
common_sampler_free(spec->smpl);
llama_batch_free(spec->batch);
delete spec;
}
@@ -151,6 +149,8 @@ llama_tokens common_speculative_gen_draft(
const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
const llama_seq_id seq_id = 0;
// reuse as much as possible from the old draft context
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
for (int i = 0; i < (int) prompt.size(); ++i) {
@@ -173,7 +173,7 @@ llama_tokens common_speculative_gen_draft(
result.reserve(params.n_draft);
if (reuse_n == 0) {
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
prompt.clear();
} else {
@@ -192,54 +192,54 @@ llama_tokens common_speculative_gen_draft(
}
if (reuse_i > 0) {
llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
llama_kv_self_seq_rm (ctx, 0, 0, reuse_i);
llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
}
if (reuse_n < (int) prompt.size()) {
llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
llama_kv_self_seq_rm (ctx, 0, reuse_n, -1);
prompt.erase(prompt.begin() + reuse_n, prompt.end());
}
}
// prepare a batch to evaluate any new tokens in the prompt
common_batch_clear(batch);
llama_batch_ext_clear(batch.get());
for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
llama_batch_ext_add_text(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false);
prompt.push_back(prompt_tgt[i]);
}
// we should rarely end-up here during normal decoding
if (batch.n_tokens > 0) {
if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
llama_decode(ctx, batch);
llama_decode_ext(ctx, batch.get());
}
const llama_pos n_past = prompt.size();
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
common_batch_clear(batch);
common_batch_add (batch, id_last, n_past, { 0 }, true);
llama_batch_ext_clear(batch.get());
llama_batch_ext_add_text(batch.get(), id_last, n_past, &seq_id, 1, true);
prompt.push_back(id_last);
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
llama_decode(ctx, batch);
llama_decode_ext(ctx, batch.get());
common_sampler_reset(smpl);
// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch);
llama_batch_ext_clear(batch.get());
common_sampler_sample(smpl, ctx, 0, true);
@@ -266,10 +266,10 @@ llama_tokens common_speculative_gen_draft(
break;
}
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
llama_batch_ext_add_text(batch.get(), id, n_past + i + 1, &seq_id, 1, true);
// evaluate the drafted tokens on the draft model
llama_decode(ctx, batch);
llama_decode_ext(ctx, batch.get());
prompt.push_back(id);
}

View File

@@ -861,6 +861,9 @@ class Model:
for token_id, token_data in added_tokens_decoder.items():
token_id = int(token_id)
token: str = token_data["content"]
if token_id >= vocab_size:
logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
continue
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
if tokens[token_id] != token.encode("utf-8"):
logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token!r}')
@@ -905,6 +908,40 @@ class Model:
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_rwkv_world(self):
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
vocab_size = self.hparams.get("vocab_size", 65536)
tokens: list[bytes] = ['<s>'.encode("utf-8")]
toktypes: list[int] = [gguf.TokenType.CONTROL]
with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
parts = line.split(' ')
assert len(parts) >= 3
token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
token = token.encode("utf-8") if isinstance(token, str) else token
assert isinstance(token, bytes)
assert len(token) == token_len
token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
tokens.append(token_text.encode("utf-8"))
toktypes.append(gguf.TokenType.NORMAL)
remainder = vocab_size - len(tokens)
assert remainder >= 0
for i in range(len(tokens), vocab_size):
tokens.append(f"[PAD{i}]".encode("utf-8"))
toktypes.append(gguf.TokenType.UNUSED)
self.gguf_writer.add_tokenizer_model("rwkv")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.chat_template = "rwkv-world"
# hack: Add '\n\n' as the EOT token to make it chat normally
special_vocab._set_special_token("eot", 261)
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
@@ -3322,6 +3359,83 @@ class Gemma2Model(Model):
return [(self.map_tensor_name(name), data_torch)]
@Model.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
class Gemma3Model(Model):
model_arch = gguf.MODEL_ARCH.GEMMA3
has_vision: bool = False
# we need to merge the text_config into the root level of hparams
def __init__(self, *args, **kwargs):
hparams = Model.load_hparams(kwargs["dir_model"])
if "text_config" in hparams:
hparams = {**hparams, **hparams["text_config"]}
kwargs["hparams"] = hparams
super().__init__(*args, **kwargs)
if "vision_config" in hparams:
logger.info("Has vision encoder, but it will be ignored")
self.has_vision = True
def write(self):
super().write()
if self.has_vision:
logger.info("NOTE: this script only convert the language model to GGUF")
logger.info(" for the vision model, please use gemma3_convert_encoder_to_gguf.py")
def set_vocab(self):
self._set_vocab_sentencepiece()
self.gguf_writer.add_add_space_prefix(False)
def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
# some default values are not specified in the hparams
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
self.gguf_writer.add_key_length(hparams.get("head_dim", 256))
self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
# both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
assert hparams.get("attn_logit_softcapping") is None
assert hparams.get("final_logit_softcapping") is None
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
if hparams.get("rope_scaling") is not None:
assert hparams["rope_scaling"]["rope_type"] == "linear"
# important: this rope_scaling is only applied for global layers, and not used by 1B model
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if name.startswith("language_model."):
name = name.replace("language_model.", "")
elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
or name.startswith("multimodal_projector.") or name.startswith("vision_model."): # this is for old HF model, should be removed later
# ignore vision tensors
return []
# remove OOV (out-of-vocabulary) rows in token_embd
if "embed_tokens.weight" in name:
vocab = self._create_vocab_sentencepiece()
tokens = vocab[0]
data_torch = data_torch[:len(tokens)]
# ref code in Gemma3RMSNorm
# output = output * (1.0 + self.weight.float())
if name.endswith("norm.weight"):
data_torch = data_torch + 1
return [(self.map_tensor_name(name), data_torch)]
@Model.register("Starcoder2ForCausalLM")
class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2
@@ -3332,38 +3446,7 @@ class Rwkv6Model(Model):
model_arch = gguf.MODEL_ARCH.RWKV6
def set_vocab(self):
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
vocab_size = self.hparams.get("vocab_size", 65536)
tokens: list[bytes] = ['<s>'.encode("utf-8")]
toktypes: list[int] = [gguf.TokenType.CONTROL]
with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
parts = line.split(' ')
assert len(parts) >= 3
token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
token = token.encode("utf-8") if isinstance(token, str) else token
assert isinstance(token, bytes)
assert len(token) == token_len
token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
tokens.append(token_text.encode("utf-8"))
toktypes.append(gguf.TokenType.NORMAL)
remainder = vocab_size - len(tokens)
assert remainder >= 0
for i in range(len(tokens), vocab_size):
tokens.append(f"[PAD{i}]".encode("utf-8"))
toktypes.append(gguf.TokenType.UNUSED)
self.gguf_writer.add_tokenizer_model("rwkv")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.chat_template = "rwkv-world"
# hack: Add '\n\n' as the EOT token to make it chat normally
special_vocab._set_special_token("eot", 261)
special_vocab.add_to_gguf(self.gguf_writer)
self._set_vocab_rwkv_world()
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
@@ -3485,6 +3568,168 @@ class RWKV6Qwen2Model(Rwkv6Model):
yield (new_name, data)
@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
class Rwkv7Model(Model):
model_arch = gguf.MODEL_ARCH.RWKV7
def set_vocab(self):
self._set_vocab_rwkv_world()
def calc_lora_rank(self, hidden_size, exponent, multiplier):
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
try:
head_size = self.hparams["head_size"]
layer_norm_eps = self.hparams["layer_norm_epsilon"]
except KeyError:
head_size = self.hparams["head_dim"]
layer_norm_eps = self.hparams["norm_eps"]
hidden_size = self.hparams["hidden_size"]
intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)
# ICLR: In-Context-Learning-Rate
try:
lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
except KeyError:
lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_file_type(self.ftype)
# required by llama.cpp, unused
self.gguf_writer.add_head_count(0)
lerp_weights: dict[int, dict[str, Tensor]] = {}
lora_needs_transpose: bool = True
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# unify tensor names here to make life easier
name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
name = name.replace("self_attn", "attention").replace("attn", "attention")
name = name.replace("time_mixer.", "")
# lora layer names in fla-hub's impl
if "_lora.lora" in name:
self.lora_needs_transpose = False
name = name.replace("_lora.lora.0.weight", "1.weight")
name = name.replace("_lora.lora.2.weight", "2.weight")
name = name.replace("_lora.lora.2.bias", "0.weight")
name = name.replace("feed_forward_norm", "ln2")
name = name.replace("g_norm", "ln_x")
if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
# some models have dummy v0/v1/v2 on first layer while others don't
# ignore them all since they are not used
return
wkv_has_gate = self.hparams.get("wkv_has_gate", True)
lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]
if bid is not None and "attention.x_" in name:
if "attention.x_x" in name:
# already concatenated
new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
data = data_torch.reshape(len(lerp_list), 1, 1, -1)
yield (new_name, data)
else:
try:
self.lerp_weights[bid][name] = data_torch
except KeyError:
self.lerp_weights[bid] = {name: data_torch}
if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
yield (new_name, data)
return
else:
data_torch = data_torch.squeeze()
new_name = self.map_tensor_name(name)
if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
new_name += ".weight"
if self.lora_needs_transpose and any(
new_name.endswith(t) for t in [
"time_mix_w1.weight", "time_mix_w2.weight",
"time_mix_a1.weight", "time_mix_a2.weight",
"time_mix_v1.weight", "time_mix_v2.weight",
"time_mix_g1.weight", "time_mix_g2.weight",
]
):
data_torch = data_torch.transpose(0, 1)
if 'r_k' in new_name:
data_torch = data_torch.flatten()
if bid == 0 and "time_mix_a" in new_name:
# dummy v0/v1/v2 on first layer
# easist way to make llama happy
yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)
yield (new_name, data_torch)
@Model.register("RwkvHybridForCausalLM")
class ARwkv7Model(Rwkv7Model):
model_arch = gguf.MODEL_ARCH.ARWKV7
def set_vocab(self):
try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
hidden_size = self.hparams["hidden_size"]
head_size = self.hparams["head_size"]
rms_norm_eps = self.hparams["rms_norm_eps"]
intermediate_size = self.hparams["intermediate_size"]
wkv_has_gate = self.hparams["wkv_has_gate"]
assert self.hparams["wkv_version"] == 7
# ICLR: In-Context-Learning-Rate
lora_rank_decay = 64
lora_rank_iclr = 64
lora_rank_value_residual_mix = 32
lora_rank_gate = 128 if wkv_has_gate else 0
# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_token_shift_count(1)
# required by llama.cpp, unused
self.gguf_writer.add_head_count(0)
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA

View File

@@ -660,8 +660,9 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|--------------------|---------------------------------------|---------------------------------------------|
| GGML_SYCL | ON (mandatory) | Enable build with SYCL code path.<br>FP32 path - recommended for better perforemance than FP16 on quantized model|
| GGML_SYCL_TARGET | INTEL *(default)* \| NVIDIA \| AMD | Set the SYCL target device type. |
| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. |
| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |
@@ -671,6 +672,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |

View File

@@ -59,24 +59,17 @@ int main(int argc, char ** argv) {
const int32_t n_kv_max = llama_n_ctx(ctx);
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(n_kv_max, 1);
// decode in batches of ctx_params.n_batch tokens
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
const int32_t n_batch_tokens = llama_batch_ext_get_n_tokens(batch);
for (int32_t i = 0; i < (int32_t) n_batch_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (n_batch_tokens - i));
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
llama_batch_ext_ptr batch_view = llama_batch_ext_ptr(llama_batch_ext_get_view(batch, i, n_tokens));
const int ret = llama_decode(ctx, batch_view);
const int ret = llama_decode_ext(ctx, batch_view.get());
if (ret != 0) {
LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
@@ -91,7 +84,8 @@ int main(int argc, char ** argv) {
// warm up
{
for (int i = 0; i < 16; ++i) {
common_batch_add(batch, 0, i, { 0 }, false);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false);
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@@ -121,18 +115,18 @@ int main(int argc, char ** argv) {
continue;
}
common_batch_clear(batch);
llama_batch_ext_clear(batch);
for (int i = 0; i < pp; ++i) {
for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
common_batch_add(batch, 0, i, { j }, false);
llama_batch_ext_add_text(batch, 0, i, &j, 1, false);
}
}
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch);
const auto t_pp_start = ggml_time_us();
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -141,7 +135,7 @@ int main(int argc, char ** argv) {
if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
}
}
@@ -150,10 +144,10 @@ int main(int argc, char ** argv) {
const auto t_tg_start = ggml_time_us();
for (int i = 0; i < tg; ++i) {
common_batch_clear(batch);
llama_batch_ext_clear(batch);
for (int j = 0; j < pl; ++j) {
common_batch_add(batch, 0, pp + i, { j }, true);
llama_batch_ext_add_text(batch, 0, pp + i, &j, 1, true);
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@@ -191,7 +185,7 @@ int main(int argc, char ** argv) {
LOG("\n");
llama_perf_context_print(ctx);
llama_batch_free(batch);
llama_batch_ext_free(batch);
llama_free(ctx);
llama_model_free(model);

View File

@@ -116,7 +116,7 @@ if llama_decode(context, batch) != 0 {
}
for i in 1 ..< n_parallel {
llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
llama_kv_self_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
}
if n_parallel > 1 {

View File

@@ -102,7 +102,7 @@ int main(int argc, char ** argv) {
// create a llama_batch
// we use this object to submit token data for decoding
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);
llama_batch_ext * batch = llama_batch_ext_init(std::max(tokens_list.size(), (size_t) n_parallel), n_parallel);
std::vector<llama_seq_id> seq_ids(n_parallel, 0);
for (int32_t i = 0; i < n_parallel; ++i) {
@@ -111,12 +111,12 @@ int main(int argc, char ** argv) {
// evaluate the initial prompt
for (size_t i = 0; i < tokens_list.size(); ++i) {
common_batch_add(batch, tokens_list[i], i, seq_ids, false);
llama_batch_ext_add_text(batch, tokens_list[i], i, seq_ids.data(), seq_ids.size(), false);
}
GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) tokens_list.size());
if (llama_model_has_encoder(model)) {
if (llama_encode(ctx, batch)) {
if (llama_encode_ext(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
@@ -126,14 +126,14 @@ int main(int argc, char ** argv) {
decoder_start_token_id = llama_vocab_bos(vocab);
}
common_batch_clear(batch);
common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
llama_batch_ext_clear(batch);
llama_batch_ext_add_text(batch, decoder_start_token_id, 0, seq_ids.data(), seq_ids.size(), false);
}
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch);
if (llama_decode(ctx, batch) != 0) {
if (llama_decode_ext(ctx, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
@@ -155,16 +155,16 @@ int main(int argc, char ** argv) {
// remember the batch index of the last token for each parallel sequence
// we need this to determine which logits to sample from
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1);
int n_cur = batch.n_tokens;
int n_cur = llama_batch_ext_get_n_tokens(batch);
int n_decode = 0;
const auto t_main_start = ggml_time_us();
while (n_cur <= n_predict) {
// prepare the next batch
common_batch_clear(batch);
llama_batch_ext_clear(batch);
// sample the next token for each parallel sequence / stream
for (int32_t i = 0; i < n_parallel; ++i) {
@@ -193,23 +193,23 @@ int main(int argc, char ** argv) {
streams[i] += common_token_to_piece(ctx, new_token_id);
i_batch[i] = batch.n_tokens;
i_batch[i] = llama_batch_ext_get_n_tokens(batch);
// push this new token for next evaluation
common_batch_add(batch, new_token_id, n_cur, { i }, true);
llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, true);
n_decode += 1;
}
// all streams are finished
if (batch.n_tokens == 0) {
if (llama_batch_ext_get_n_tokens(batch) == 0) {
break;
}
n_cur += 1;
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
@@ -234,7 +234,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n");
llama_batch_free(batch);
llama_batch_ext_free(batch);
llama_sampler_free(smpl);
llama_free(ctx);

View File

@@ -342,8 +342,9 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
}
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
llama_kv_cache_clear(ctx);
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
llama_kv_self_clear(ctx);
auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}

View File

@@ -26,36 +26,36 @@ static std::vector<std::string> split_lines(const std::string & s, const std::st
return lines;
}
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
size_t n_tokens = tokens.size();
for (size_t i = 0; i < n_tokens; i++) {
common_batch_add(batch, tokens[i], i, { seq_id }, true);
batch.add_text(tokens[i], i, seq_id, true);
}
}
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
const struct llama_model * model = llama_get_model(ctx);
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
// encoder-only model
if (llama_encode(ctx, batch) < 0) {
if (llama_encode_ext(ctx, batch.get()) < 0) {
LOG_ERR("%s : failed to encode\n", __func__);
}
} else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
// decoder-only model
if (llama_decode(ctx, batch) < 0) {
if (llama_decode_ext(ctx, batch.get()) < 0) {
LOG_ERR("%s : failed to decode\n", __func__);
}
}
for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
if (!batch.tokens[i].logits) {
continue;
}
@@ -69,8 +69,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
GGML_ASSERT(embd != NULL && "failed to get token embeddings");
} else {
// try to get sequence embeddings - supported only when pooling_type is not NONE
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
embd_pos = batch.seq_id[i][0];
embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
embd_pos = batch.tokens[i].seq_id;
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
}
@@ -171,7 +171,7 @@ int main(int argc, char ** argv) {
// initialize batch
const int n_prompts = prompts.size();
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
struct common_batch batch = common_batch(n_batch, 1);
// count number of embeddings
int n_embd_count = 0;
@@ -198,12 +198,12 @@ int main(int argc, char ** argv) {
const uint64_t n_toks = inp.size();
// encode if at capacity
if (batch.n_tokens + n_toks > n_batch) {
if (batch.get_n_tokens() + n_toks > n_batch) {
float * out = emb + e * n_embd;
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.get_n_tokens() : s;
s = 0;
common_batch_clear(batch);
batch.clear();
}
// add to batch
@@ -319,7 +319,6 @@ int main(int argc, char ** argv) {
llama_perf_context_print(ctx);
// clean up
llama_batch_free(batch);
llama_backend_free();
return 0;

View File

@@ -134,7 +134,8 @@ static bool run(llama_context * ctx, const common_params & params) {
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}

View File

@@ -13,10 +13,10 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(llama_n_batch(ctx), 1);
for (uint64_t i = 0; i < sentences.size(); i++) {
common_batch_clear(batch);
llama_batch_ext_clear(batch);
const std::string input_string = instruction + sentences[i];
@@ -41,16 +41,17 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
// add input to batch (this increments n_tokens)
for (int32_t j = 0; j < n_toks; j++) {
common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, inputs[j], j, &seq_id, 1 , j >= n_inst);
}
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
llama_set_embeddings(ctx, true);
llama_set_causal_attn(ctx, false);
// run model
llama_decode(ctx, batch);
llama_decode_ext(ctx, batch);
// get embedding dimensions
uint64_t n_embd = llama_model_n_embd(model);
@@ -89,7 +90,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
#endif
}
llama_batch_free(batch);
llama_batch_ext_free(batch);
return result;
}
@@ -102,29 +103,30 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
llama_token eos_token = llama_vocab_eos(vocab);
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
llama_batch_ext * bat = llama_batch_ext_init(llama_n_batch(ctx), 1);
std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
int32_t i_current_token = 0;
while (true) {
common_batch_clear(bat);
llama_batch_ext_clear(bat);
{
const int32_t n_inputs = inputs.size();
for (int32_t i = 0; i < n_inputs; i++) {
common_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(bat, inputs[i], i_current_token++, &seq_id, 1, i == n_inputs - 1);
}
}
inputs.clear();
llama_decode(ctx, bat);
llama_decode_ext(ctx, bat);
llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(bat) - 1);
if (token == eos_token) {
break;
@@ -145,7 +147,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
std::printf("\n");
}
llama_batch_free(bat);
llama_batch_ext_free(bat);
return result;
}

View File

@@ -495,9 +495,9 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@@ -511,14 +511,15 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
tokens[batch_start] = llama_vocab_bos(vocab);
}
common_batch_clear(batch);
llama_batch_ext_clear(batch);
for (int i = 0; i < batch_size; i++) {
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
}
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
llama_batch_free(batch);
llama_batch_ext_free(batch);
return false;
}
@@ -531,7 +532,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
}
}
llama_batch_free(batch);
llama_batch_ext_free(batch);
const auto t_end = std::chrono::high_resolution_clock::now();

View File

@@ -332,8 +332,8 @@ int main(int argc, char ** argv) {
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
llama_kv_self_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_self_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
n_past -= n_discard;
@@ -353,7 +353,8 @@ int main(int argc, char ** argv) {
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}

View File

@@ -1427,7 +1427,7 @@ struct sql_printer : public printer {
}
};
static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);
const llama_model * model = llama_get_model(ctx);
@@ -1444,14 +1444,15 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
for (int i = 1; i < n_tokens; i++) {
tokens[i] = std::rand() % n_vocab;
}
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true);
llama_decode_ext(ctx, batch.get());
n_processed += n_tokens;
}
llama_synchronize(ctx);
}
static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);
const llama_model * model = llama_get_model(ctx);
@@ -1461,7 +1462,8 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
for (int i = 0; i < n_gen; i++) {
llama_decode(ctx, llama_batch_get_one(&token, 1));
auto batch = llama_batch_ext_ptr::init_from_text(&token, 1, n_past + i, 0, true);
llama_decode_ext(ctx, batch.get());
llama_synchronize(ctx);
token = std::rand() % n_vocab;
}
@@ -1578,7 +1580,7 @@ int main(int argc, char ** argv) {
test t(inst, lmodel, ctx);
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
// cool off before the test
if (params.delay) {
@@ -1608,17 +1610,17 @@ int main(int argc, char ** argv) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
}
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
}
if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
}
test_gen(ctx, 1, t.n_threads);
test_gen(ctx, 1, 0, t.n_threads);
}
for (int i = 0; i < params.reps; i++) {
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
uint64_t t_start = get_time_ns();
@@ -1627,14 +1629,14 @@ int main(int argc, char ** argv) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
i + 1, params.reps);
}
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
}
if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
i + 1, params.reps);
}
test_gen(ctx, t.n_gen, t.n_threads);
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
}
uint64_t t_ns = get_time_ns() - t_start;

View File

@@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
}
batch->logits[batch->n_tokens - 1] = true;
llama_kv_cache_clear(context);
llama_kv_self_clear(context);
const auto t_pp_start = ggml_time_us();
if (llama_decode(context, *batch) != 0) {
@@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
LOGi("Benchmark text generation (tg)");
llama_kv_cache_clear(context);
llama_kv_self_clear(context);
const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) {
@@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
const auto t_tg_end = ggml_time_us();
llama_kv_cache_clear(context);
llama_kv_self_clear(context);
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@@ -448,5 +448,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
llama_kv_self_clear(reinterpret_cast<llama_context *>(context));
}

View File

@@ -16,7 +16,7 @@ Open `llama.swiftui.xcodeproj` project in Xcode and you should be able to build
a simulator or a real device.
To use the framework with a different project, the XCFramework can be added to the project by
adding `build-ios/llama.xcframework` by dragging and dropping it into the project navigator, or
adding `build-apple/llama.xcframework` by dragging and dropping it into the project navigator, or
by manually selecting the framework in the "Frameworks, Libraries, and Embedded Content" section
of the project settings.

View File

@@ -210,7 +210,7 @@ actor LlamaContext {
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
llama_kv_cache_clear(context)
llama_kv_self_clear(context)
let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
@@ -223,7 +223,7 @@ actor LlamaContext {
// bench text generation
llama_kv_cache_clear(context)
llama_kv_self_clear(context)
let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
@@ -242,7 +242,7 @@ actor LlamaContext {
let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000;
llama_kv_cache_clear(context)
llama_kv_self_clear(context)
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@@ -292,7 +292,7 @@ actor LlamaContext {
func clear() {
tokens_list.removeAll()
temporary_invalid_cchars.removeAll()
llama_kv_cache_clear(context)
llama_kv_self_clear(context)
}
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {

View File

@@ -51,6 +51,13 @@ install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
set(TARGET llama-gemma3-cli)
add_executable(${TARGET} gemma3-cli.cpp)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
set(TARGET llama-llava-clip-quantize-cli)
add_executable(${TARGET} clip-quantize-cli.cpp)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize-cli)

View File

@@ -0,0 +1,30 @@
# Gemma 3 vision
> [!IMPORTANT]
>
> This is very experimental, only used for demo purpose.
## How to get mmproj.gguf?
```bash
cd gemma-3-4b-it
python ../llama.cpp/examples/llava/gemma3_convert_encoder_to_gguf.py .
# output file is mmproj.gguf
```
## How to run it?
What you need:
- The text model GGUF, can be converted using `convert_hf_to_gguf.py`
- The mmproj file from step above
- An image file
```bash
# build
cmake -B build
cmake --build build --target llama-gemma3-cli
# run it
./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
```

View File

@@ -136,6 +136,8 @@ static std::string format(const char * fmt, ...) {
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE "model.image_newline"
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
#define TN_MINICPMV_QUERY "resampler.query"
@@ -162,6 +164,7 @@ enum projector_type {
PROJECTOR_TYPE_RESAMPLER,
PROJECTOR_TYPE_GLM_EDGE,
PROJECTOR_TYPE_MERGER,
PROJECTOR_TYPE_GEMMA3,
PROJECTOR_TYPE_UNKNOWN,
};
@@ -172,6 +175,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
};
@@ -298,7 +302,7 @@ static projector_type clip_projector_type_from_string(const std::string & name)
return kv.first;
}
}
return PROJECTOR_TYPE_UNKNOWN;
throw std::runtime_error(format("Unknown projector type: %s", name.c_str()));
}
#ifdef CLIP_DEBUG_FUNCTIONS
@@ -555,6 +559,10 @@ struct clip_vision_model {
struct ggml_tensor * mm_model_ln_kv_b;
struct ggml_tensor * mm_model_ln_post_w;
struct ggml_tensor * mm_model_ln_post_b;
// gemma3
struct ggml_tensor * mm_input_proj_w;
struct ggml_tensor * mm_soft_emb_norm_w;
};
struct clip_ctx {
@@ -569,7 +577,7 @@ struct clip_ctx {
struct clip_vision_model vision_model;
projector_type proj_type = PROJECTOR_TYPE_MLP;
int32_t max_feature_layer;
int32_t max_feature_layer; // unused in newer models like gemma3
float image_mean[3];
float image_std[3];
bool use_gelu = false;
@@ -595,7 +603,7 @@ struct clip_ctx {
ggml_backend_sched_ptr sched;
struct clip_image_size * load_image_size;
struct clip_image_size * load_image_size = nullptr;
clip_ctx(clip_context_params & ctx_params) {
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
@@ -631,7 +639,159 @@ struct clip_ctx {
}
};
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
const int image_size = hparams.image_size;
int image_size_width = image_size;
int image_size_height = image_size;
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
const int n_layer = hparams.n_layer;
const float eps = hparams.eps;
GGML_ASSERT(imgs->size == 1); // batch_size == 1
struct ggml_init_params params = {
/*.mem_size =*/ ctx->buf_compute_meta.size(),
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
/*.no_alloc =*/ true,
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
// input raw
struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw);
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
inp = ggml_add(ctx0, inp, model.patch_bias);
// position embeddings
struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
// loop over layers
for (int il = 0; il < n_layer; il++) {
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
// layernorm1
{
cur = ggml_norm(ctx0, cur, eps);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b);
}
// self-attention
{
struct ggml_tensor * Q =
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
struct ggml_tensor * K =
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
struct ggml_tensor * V =
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf((float)d_head));
KQ = ggml_soft_max_inplace(ctx0, KQ);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
}
// attention output
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
// re-add the layer input, e.g., residual
cur = ggml_add(ctx0, cur, embeddings);
embeddings = cur; // embeddings = residual, cur = hidden_states
// layernorm2
{
cur = ggml_norm(ctx0, cur, eps);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
}
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
// siglip uses gelu
cur = ggml_gelu(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
// residual 2
cur = ggml_add(ctx0, embeddings, cur);
embeddings = cur;
}
// post-layernorm
if (ctx->has_post_norm) {
embeddings = ggml_norm(ctx0, embeddings, eps);
ggml_set_name(embeddings, "post_ln");
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
}
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
const int batch_size = 1;
const int mm_tokens_per_image = 256; // default value for gemma3
const int tokens_per_side = sqrt(mm_tokens_per_image);
const int patches_per_image = sqrt(num_patches);
const int kernel_size = patches_per_image / tokens_per_side;
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size);
// doing a pool2d to reduce the number of output tokens to 256
embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size);
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
// apply norm before projection
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
embeddings = ggml_mul(ctx0, embeddings, model.mm_soft_emb_norm_w);
// apply projection
embeddings = ggml_mul_mat(ctx0,
ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
embeddings);
}
// build the graph
ggml_build_forward_expand(gf, embeddings);
ggml_free(ctx0);
return gf;
}
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
if (!ctx->has_vision_encoder) {
LOG_ERR("This gguf file seems to have no vision encoder\n");
return nullptr;
@@ -1177,7 +1337,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
} else {
GGML_ABORT("fatel error");
}
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
}
else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
@@ -1199,6 +1360,15 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
return gf;
}
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
return clip_image_build_graph_siglip(ctx, imgs);
} else {
// TODO: we should have one build_* function per model
return clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
}
}
// read and create ggml_context containing the tensors and their data
struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
return clip_init(fname, clip_context_params{
@@ -1358,8 +1528,12 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
GGML_ASSERT(new_clip->has_vision_encoder);
GGML_ASSERT(!new_clip->has_text_encoder);
idx = get_key_idx(ctx, KEY_USE_GELU);
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
try {
idx = get_key_idx(ctx, KEY_USE_GELU);
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
} catch (std::runtime_error & /*e*/) {
new_clip->use_gelu = false;
}
try {
idx = get_key_idx(ctx, KEY_USE_SILU);
@@ -1567,11 +1741,17 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
}
try {
vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
} catch(const std::exception& /*e*/) {
vision_model.patch_embeddings_0 = nullptr;
}
try {
vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
} catch(const std::exception& /*e*/) {
LOG_ERR("%s: failed to load vision model tensors\n", __func__);
vision_model.position_embeddings = nullptr;
}
try {
vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1);
} catch(const std::exception& /*e*/) {
@@ -1682,6 +1862,10 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
}
else if (new_clip->proj_type == PROJECTOR_TYPE_GEMMA3) {
vision_model.mm_input_proj_w = get_tensor(new_clip->ctx_data, TN_MM_INP_PROJ);
vision_model.mm_soft_emb_norm_w = get_tensor(new_clip->ctx_data, TN_MM_SOFT_EMB_N);
}
else {
std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
@@ -2223,7 +2407,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
return true;
}
if (ctx->has_glm_projector) {
if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
res_imgs->size = 1;
res_imgs->data = new clip_image_f32[res_imgs->size];
clip_image_u8 resized_image;
@@ -2748,6 +2932,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
}
else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
// do nothing
}
else {
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
@@ -2960,6 +3147,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
return ctx->vision_model.mm_1_b->ne[0];
}
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
return ctx->vision_model.mm_input_proj_w->ne[0];
}
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));

View File

@@ -0,0 +1,311 @@
#include "arg.h"
#include "log.h"
#include "common.h"
#include "sampling.h"
#include "clip.h"
#include "stb_image.h"
#include "llama.h"
#include "llama-cpp.h"
#include "ggml.h"
#include "console.h"
#include <vector>
#include <limits.h>
#include <inttypes.h>
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <signal.h>
#endif
static bool g_is_generating = false;
/**
* Please note that this is NOT a production-ready stuff.
* It is a playground for trying Gemma 3 vision capabilities.
* For contributors: please keep this code simple and easy to understand.
*/
static void show_additional_info(int /*argc*/, char ** argv) {
LOG(
"Experimental CLI for using Gemma 3 vision model\n\n"
"Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> -p <prompt>\n\n"
" -m and --mmproj are required\n"
" --image and -p are optional, if NOT provided, the CLI will run in chat mode\n",
argv[0]
);
}
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void sigint_handler(int signo) {
if (signo == SIGINT) {
if (g_is_generating) {
g_is_generating = false;
} else {
console::cleanup();
LOG("\nInterrupted by user\n");
_exit(130);
}
}
}
#endif
struct gemma3_context {
struct clip_ctx * ctx_clip = NULL;
common_init_result llama_init;
llama_model * model;
llama_context * lctx;
const llama_vocab * vocab;
llama_batch_ext_ptr batch;
int n_threads = 1;
llama_pos n_past = 0;
gemma3_context(common_params & params) : llama_init(common_init_from_params(params)) {
model = llama_init.model.get();
lctx = llama_init.context.get();
vocab = llama_model_get_vocab(model);
n_threads = params.cpuparams.n_threads;
batch.reset(llama_batch_ext_init(params.n_batch, 1));
init_clip_model(params);
}
void init_clip_model(common_params & params) {
const char * clip_path = params.mmproj.c_str();
ctx_clip = clip_model_load(clip_path, params.verbosity > 1);
}
~gemma3_context() {
clip_free(ctx_clip);
}
};
static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
llama_batch_ext_clear(ctx.batch.get());
for (llama_token & t : tokens) {
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(ctx.batch.get(), t, ctx.n_past++, &seq_id, 1, false);
}
if (logits_last) {
llama_batch_ext_set_output_last(ctx.batch.get());
}
// LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
LOG_ERR("Failed to decode text\n");
return 1;
}
return 0;
}
static int eval_image(gemma3_context & ctx, std::string & fname) {
std::vector<float> image_embd_v;
int n_embd = llama_model_n_embd(ctx.model);
int n_tokens = 256;
image_embd_v.resize(n_tokens * n_embd);
bool ok;
struct clip_image_u8 * img_u8 = clip_image_u8_init();
ok = clip_image_load_from_file(fname.c_str(), img_u8);
if (!ok) {
LOG_ERR("Unable to load image %s\n", fname.c_str());
clip_image_u8_free(img_u8);
return 2; // non-fatal error
}
clip_image_f32_batch batch_f32;
ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32);
if (!ok) {
LOG_ERR("Unable to preprocess image\n");
clip_image_f32_batch_free(&batch_f32);
clip_image_u8_free(img_u8);
return 1;
}
int64_t t0 = ggml_time_ms();
LOG("Encoding image %s\n", fname.c_str());
ok = clip_image_batch_encode(ctx.ctx_clip, ctx.n_threads, &batch_f32, image_embd_v.data());
if (!ok) {
LOG_ERR("Unable to encode image\n");
clip_image_f32_batch_free(&batch_f32);
clip_image_u8_free(img_u8);
return 1;
}
LOG("Image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
clip_image_f32_batch_free(&batch_f32);
clip_image_u8_free(img_u8);
// decode image embeddings
int64_t t1 = ggml_time_ms();
eval_text(ctx, "<start_of_image>");
llama_set_causal_attn(ctx.lctx, false);
llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0));
if (llama_decode_ext(ctx.lctx, batch_img.get())) {
LOG_ERR("failed to decode image\n");
return 1;
}
ctx.n_past += n_tokens;
llama_set_causal_attn(ctx.lctx, true);
eval_text(ctx, "<end_of_image>");
LOG("Image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
return 0;
}
static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
for (int i = 0; i < n_predict; i++) {
if (i > n_predict || !g_is_generating) {
printf("\n");
break;
}
llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
common_sampler_accept(smpl, token_id, true);
if (llama_vocab_is_eog(ctx.vocab, token_id)) {
printf("\n");
break; // end of generation
}
printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
fflush(stdout);
// eval the token
llama_batch_ext_clear(ctx.batch.get());
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true);
if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
LOG_ERR("failed to decode token\n");
return 1;
}
}
return 0;
}
int main(int argc, char ** argv) {
ggml_time_init();
common_params params;
params.sampling.temp = 0.2; // lower temp by default for better quality
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
return 1;
}
common_init();
if (params.mmproj.empty()) {
show_additional_info(argc, argv);
return 1;
}
gemma3_context ctx(params);
printf("%s: %s\n", __func__, params.model.c_str());
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
// ctrl+C handling
{
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = sigint_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
#elif defined (_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
}
if (eval_text(ctx, "<bos>")) {
return 1;
}
if (is_single_turn) {
g_is_generating = true;
if (eval_text(ctx, "<start_of_turn>user\n")) {
return 1;
}
for (auto & fname : params.image) {
if (eval_image(ctx, fname)) {
return 1;
}
}
if (eval_text(ctx, params.prompt + "<end_of_turn><start_of_turn>model\n", true)) {
return 1;
}
if (generate_response(ctx, smpl, n_predict)) {
return 1;
}
} else {
LOG("\n Running in chat mode, available commands:");
LOG("\n /image <path> load an image");
LOG("\n /clear clear the chat history");
LOG("\n /quit or /exit exit the program");
LOG("\n");
if (eval_text(ctx, "<start_of_turn>user\n")) {
return 1;
}
while (true) {
g_is_generating = false;
LOG("\n> ");
console::set_display(console::user_input);
std::string line;
console::readline(line, false);
console::set_display(console::reset);
line = string_strip(line);
if (line.empty()) {
continue;
}
if (line == "/quit" || line == "/exit") {
break;
}
if (line == "/clear") {
ctx.n_past = 0;
llama_kv_self_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS
LOG("Chat history cleared\n\n");
continue;
}
g_is_generating = true;
if (line.find("/image") == 0) {
std::string image = line.substr(7);
int res = eval_image(ctx, image);
if (res == 2) {
continue; // image not found
}
if (res) {
return 1;
}
continue;
}
if (eval_text(ctx, line + "<end_of_turn><start_of_turn>model\n", true)) {
return 1;
}
if (generate_response(ctx, smpl, n_predict)) {
return 1;
}
if (eval_text(ctx, "<end_of_turn><start_of_turn>user\n")) {
return 1;
}
}
}
return 0;
}

View File

@@ -0,0 +1,307 @@
import gguf
import argparse
import logging
import sys
import torch
import json
import os
import numpy as np
from typing import cast, ContextManager, Any, Iterator
from pathlib import Path
from torch import Tensor
logger = logging.getLogger("gemma3-mmproj")
# (copied from convert_hf_to_gguf.py)
# tree of lazy tensors
class LazyTorchTensor(gguf.LazyBase):
_tensor_type = torch.Tensor
# to keep the type-checker happy
dtype: torch.dtype
shape: torch.Size
# only used when converting a torch.Tensor to a np.ndarray
_dtype_map: dict[torch.dtype, type] = {
torch.float16: np.float16,
torch.float32: np.float32,
}
# used for safetensors slices
# ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
# TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
_dtype_str_map: dict[str, torch.dtype] = {
"F64": torch.float64,
"F32": torch.float32,
"BF16": torch.bfloat16,
"F16": torch.float16,
# "U64": torch.uint64,
"I64": torch.int64,
# "U32": torch.uint32,
"I32": torch.int32,
# "U16": torch.uint16,
"I16": torch.int16,
"U8": torch.uint8,
"I8": torch.int8,
"BOOL": torch.bool,
"F8_E4M3": torch.float8_e4m3fn,
"F8_E5M2": torch.float8_e5m2,
}
def numpy(self) -> gguf.LazyNumpyTensor:
dtype = self._dtype_map[self.dtype]
return gguf.LazyNumpyTensor(
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
args=(self,),
func=(lambda s: s.numpy())
)
@classmethod
def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
return torch.empty(size=shape, dtype=dtype, device="meta")
@classmethod
def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
dtype = cls._dtype_str_map[st_slice.get_dtype()]
shape: tuple[int, ...] = tuple(st_slice.get_shape())
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
return cast(torch.Tensor, lazy)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
del types # unused
if kwargs is None:
kwargs = {}
if func is torch.Tensor.numpy:
return args[0].numpy()
return cls._wrap_fn(func)(*args, **kwargs)
class Gemma3VisionTower:
hparams: dict
gguf_writer: gguf.GGUFWriter
fname_out: Path
ftype: gguf.LlamaFileType
@staticmethod
def load_hparams(dir_model: Path):
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
return json.load(f)
@staticmethod
def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
part_names: list[str] = []
for filename in os.listdir(dir_model):
if filename.startswith(prefix) and filename.endswith(suffix):
part_names.append(filename)
part_names.sort()
return part_names
def __init__(self,
dir_model: Path,
fname_out: Path,
ftype: gguf.LlamaFileType,
is_big_endian: bool,):
hparams = Gemma3VisionTower.load_hparams(dir_model)
self.hparams = hparams
self.fname_out = fname_out
self.ftype = ftype
endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.gguf_writer = gguf.GGUFWriter(path=None, arch="clip", endianess=endianess)
text_config = hparams["text_config"]
vision_config = hparams["vision_config"]
assert hparams["architectures"][0] == "Gemma3ForConditionalGeneration"
assert text_config is not None
assert vision_config is not None
self.gguf_writer.add_string ("clip.projector_type", "gemma3")
self.gguf_writer.add_bool ("clip.has_text_encoder", False)
self.gguf_writer.add_bool ("clip.has_vision_encoder", True)
self.gguf_writer.add_bool ("clip.has_llava_projector", False) # legacy
self.gguf_writer.add_uint32 ("clip.vision.image_size", vision_config["image_size"])
self.gguf_writer.add_uint32 ("clip.vision.patch_size", vision_config["patch_size"])
self.gguf_writer.add_uint32 ("clip.vision.embedding_length", vision_config["hidden_size"])
self.gguf_writer.add_uint32 ("clip.vision.feed_forward_length", vision_config["intermediate_size"])
self.gguf_writer.add_uint32 ("clip.vision.projection_dim", text_config["hidden_size"])
self.gguf_writer.add_uint32 ("clip.vision.block_count", vision_config["num_hidden_layers"])
self.gguf_writer.add_uint32 ("clip.vision.attention.head_count", vision_config["num_attention_heads"])
self.gguf_writer.add_float32("clip.vision.attention.layer_norm_epsilon", vision_config.get("layer_norm_eps", 1e-6))
# default values taken from HF tranformers code
self.gguf_writer.add_array ("clip.vision.image_mean", [0.5, 0.5, 0.5])
self.gguf_writer.add_array ("clip.vision.image_std", [0.5, 0.5, 0.5])
self.gguf_writer.add_bool ("clip.use_gelu", True)
# load tensors
for name, data_torch in self.get_tensors(dir_model):
# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)
self.add_tensor(name, data_torch)
def get_tensors(self, dir_model: Path) -> Iterator[tuple[str, Tensor]]:
part_names = Gemma3VisionTower.get_model_part_names(dir_model, "model", ".safetensors")
tensor_names_from_parts: set[str] = set()
for part_name in part_names:
logger.info(f"gguf: loading model part '{part_name}'")
from safetensors import safe_open
ctx = cast(ContextManager[Any], safe_open(dir_model / part_name, framework="pt", device="cpu"))
with ctx as model_part:
tensor_names_from_parts.update(model_part.keys())
for name in model_part.keys():
data = model_part.get_slice(name)
data = LazyTorchTensor.from_safetensors_slice(data)
yield name, data
def add_tensor(self, name: str, data_torch: Tensor):
is_1d = len(data_torch.shape) == 1
is_embd = ".embeddings." in name
old_dtype = data_torch.dtype
can_quantize = not is_1d and not is_embd
data_qtype = gguf.GGMLQuantizationType.F32
# this is to support old checkpoint
# TODO: remove this when we have the final model
name = name.replace("vision_model.vision_model.", "vision_tower.vision_model.")
name = name.replace("multimodal_projector.", "multi_modal_projector.")
# filter only vision tensors
if not name.startswith("vision_tower.vision_model.") and not name.startswith("multi_modal_projector."):
return
# prefix
name = name.replace("vision_tower.vision_model.encoder.layers.", "v.blk.")
name = name.replace("vision_tower.vision_model.", "v.")
# projector and input embd
name = name.replace(".embeddings.patch_embedding.", ".patch_embd.")
name = name.replace(".embeddings.position_embedding.", ".position_embd.")
name = name.replace(
"multi_modal_projector.mm_input_projection_weight",
"mm.input_projection.weight"
)
name = name.replace(
"multi_modal_projector.mm_soft_emb_norm.weight",
"mm.soft_emb_norm.weight"
)
name = name.replace("post_layernorm.", "post_ln.")
# each block
name = name.replace(".self_attn.k_proj.", ".attn_k.")
name = name.replace(".self_attn.v_proj.", ".attn_v.")
name = name.replace(".self_attn.q_proj.", ".attn_q.")
name = name.replace(".self_attn.out_proj.", ".attn_out.")
name = name.replace(".layer_norm1.", ".ln1.")
name = name.replace(".layer_norm2.", ".ln2.")
name = name.replace(".mlp.fc1.", ".ffn_down.")
name = name.replace(".mlp.fc2.", ".ffn_up.")
if can_quantize:
if self.ftype == gguf.LlamaFileType.ALL_F32:
data_qtype = gguf.GGMLQuantizationType.F32
elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
data_qtype = gguf.GGMLQuantizationType.F16
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
data_qtype = gguf.GGMLQuantizationType.BF16
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
data_qtype = gguf.GGMLQuantizationType.Q8_0
else:
raise ValueError(f"Unsupported file type: {self.ftype}")
# corrent norm value ; only this "soft_emb_norm" need to be corrected as it's part of Gemma projector
# the other norm values are part of SigLIP model, and they are already correct
# ref code: Gemma3RMSNorm
if "soft_emb_norm.weight" in name:
logger.info(f"Correcting norm value for '{name}'")
data_torch = data_torch + 1
data = data_torch.numpy()
try:
data = gguf.quants.quantize(data, data_qtype)
except Exception as e:
logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16")
data_qtype = gguf.GGMLQuantizationType.F16
data = gguf.quants.quantize(data, data_qtype)
# reverse shape to make it similar to the internal ggml dimension order
shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype)
def write(self):
self.gguf_writer.write_header_to_file(path=self.fname_out)
self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.write_tensors_to_file(progress=True)
self.gguf_writer.close()
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Convert Gemma 3 vision tower safetensors to GGUF format",)
parser.add_argument(
"--outfile", type=Path, default="mmproj.gguf",
help="path to write to",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
help="output format",
)
parser.add_argument(
"--bigendian", action="store_true",
help="model is executed on big endian machine",
)
parser.add_argument(
"model", type=Path,
help="directory containing model file",
nargs="?",
)
parser.add_argument(
"--verbose", action="store_true",
help="increase output verbosity",
)
args = parser.parse_args()
if args.model is None:
parser.error("the following arguments are required: model")
return args
def main() -> None:
args = parse_args()
if args.verbose:
logging.basicConfig(level=logging.DEBUG)
else:
logging.basicConfig(level=logging.INFO)
dir_model = args.model
if not dir_model.is_dir():
logger.error(f'Error: {args.model} is not a directory')
sys.exit(1)
ftype_map: dict[str, gguf.LlamaFileType] = {
"f32": gguf.LlamaFileType.ALL_F32,
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
}
logger.info(f"Loading model: {dir_model.name}")
with torch.inference_mode():
gemma3_vision_tower = Gemma3VisionTower(
dir_model=dir_model,
fname_out=args.outfile,
ftype=ftype_map[args.outtype],
is_big_endian=args.bigendian,
)
gemma3_vision_tower.write()
if __name__ == '__main__':
main()

View File

@@ -20,7 +20,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
auto batch = llama_batch_ext_ptr::init_from_text(&tokens[i], n_eval, *n_past, 0, true);
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;
}

View File

@@ -2,6 +2,7 @@
#include "llava.h"
#include "llama.h"
#include "llama-cpp.h"
#include <algorithm>
#include <cerrno>
@@ -438,39 +439,6 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
return true;
}
struct llava_embd_batch {
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
logits .resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
};
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));
@@ -480,8 +448,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
n_eval = n_batch;
}
float * embd = image_embed->embed+i*n_embd;
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
if (llama_decode(ctx_llama, llava_batch.batch)) {
auto batch = llama_batch_ext_ptr::init_from_embd(embd, n_eval, n_embd, 0, 0);
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}

View File

@@ -101,7 +101,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
auto batch = llama_batch_ext_ptr::init_from_text(&tokens[i], n_eval, *n_past, 0, true);
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;
}

View File

@@ -66,17 +66,11 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
llama_batch batch = {
int32_t(n_eval), // n_tokens
nullptr, // token
(image_embed->embed+i*n_embd), // embed
batch_mrope_pos.data(), // pos
nullptr, // n_seq_id
nullptr, // seq_id
nullptr, // logits
};
float * batch_embd = image_embed->embed+i*n_embd;
auto batch = llama_batch_ext_ptr::init_from_embd(batch_embd, n_eval, n_embd, 0, 0);
llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval);
if (llama_decode(ctx_llama, batch)) {
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
@@ -95,16 +89,24 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
if (n_eval > n_batch) {
n_eval = n_batch;
}
auto batch = llama_batch_get_one(&tokens[i], n_eval);
// TODO: add mrope pos ids somewhere else
pos.resize(batch.n_tokens * 4);
std::fill(pos.begin(), pos.end(), 0);
for (int j = 0; j < batch.n_tokens * 3; j ++) {
pos[j] = *st_pos_id + (j % batch.n_tokens);
}
batch.pos = pos.data();
if (llama_decode(ctx_llama, batch)) {
// TODO: add mrope pos ids somewhere else
int n_tokens = n_eval;
pos.resize(n_tokens * 4);
std::fill(pos.begin(), pos.end(), 0);
for (int j = 0; j < n_tokens * 3; j ++) {
pos[j] = *st_pos_id + (j % n_tokens);
}
llama_batch_ext_ptr batch(llama_batch_ext_init(n_eval, 1));
for (int j = 0; j < n_eval; j++) {
llama_token token = tokens[i + j];
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch.get(), token, pos[j], &seq_id, 1, false);
}
llama_batch_ext_set_output_last(batch.get());
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;
}

View File

@@ -92,11 +92,13 @@ int main(int argc, char ** argv) {
const auto t_enc_start = ggml_time_us();
// eval the prompt
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
llama_decode_ext(ctx, batch0.get());
llama_decode_ext(ctx, batch1.get());
for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
}
const auto t_enc_end = ggml_time_us();
@@ -115,7 +117,7 @@ int main(int argc, char ** argv) {
// seq_id == 0 : the current input token
// seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
// seq_id [W + 1, W + G] : verification n-grams
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
llama_batch_ext * batch = llama_batch_ext_init(params.n_ctx, W + G + 1);
// target model sampling context
struct common_sampler * smpl = common_sampler_init(model, params.sampling);
@@ -204,10 +206,10 @@ int main(int argc, char ** argv) {
// V V V V V V
// id
{
common_batch_clear(batch);
llama_batch_ext_clear(batch);
// current token - first token of the first level
common_batch_add(batch, id, n_past, seq_id_all, true);
llama_batch_ext_add_text(batch, id, n_past, seq_id_all.data(), seq_id_all.size(), true);
// verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
{
@@ -230,9 +232,10 @@ int main(int argc, char ** argv) {
const llama_token t = ngrams_observed.tokens[idx + j];
ngrams_cur[g].tokens [j + 1] = t;
ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
ngrams_cur[g].i_batch[j + 1] = llama_batch_ext_get_n_tokens(batch);
common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
llama_seq_id seq_id = W + 1 + g;
llama_batch_ext_add_text(batch, t, n_past + j + 1, &seq_id, 1, true);
}
}
}
@@ -244,18 +247,20 @@ int main(int argc, char ** argv) {
seq_id_look[j] = i + j + 1;
}
common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
llama_batch_ext_add_text(batch, tokens_j[0][i], n_past + i,
seq_id_look.data(), seq_id_look.size(), false);
}
// fill the rest of the levels
for (int j = 1; j < N - 1; j++) {
for (int i = 0; i < W; i++) {
common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
llama_seq_id seq_id = i + 1;
llama_batch_ext_add_text(batch, tokens_j[j][i], n_past + j + i, &seq_id, 1, j == N - 2);
}
}
}
if (llama_decode(ctx, batch) != 0) {
if (llama_decode_ext(ctx, batch) != 0) {
LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__);
return 1;
}
@@ -438,17 +443,17 @@ int main(int argc, char ** argv) {
// KV cache management
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
llama_kv_self_seq_rm(ctx, -1, n_past, -1);
if (seq_id_best != 0) {
// if a verification token matched, we keep the best sequence and remove the rest
// this leads to some KV cache fragmentation
llama_kv_cache_seq_keep(ctx, seq_id_best);
llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
llama_kv_self_seq_keep(ctx, seq_id_best);
llama_kv_self_seq_cp (ctx, seq_id_best, 0, -1, -1);
llama_kv_self_seq_rm (ctx, seq_id_best, -1, -1);
for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
}
}
}
@@ -475,7 +480,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_view_free(&kvc_view);
llama_batch_free(batch);
llama_batch_ext_free(batch);
llama_backend_free();

View File

@@ -91,8 +91,10 @@ int main(int argc, char ** argv){
const auto t_enc_start = ggml_time_us();
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
llama_decode_ext(ctx, batch0.get());
llama_decode_ext(ctx, batch1.get());
const auto t_enc_end = ggml_time_us();
@@ -108,7 +110,7 @@ int main(int argc, char ** argv){
std::vector<llama_token> draft;
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
llama_batch_ext * batch_tgt = llama_batch_ext_init(params.n_ctx, 1);
// debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
@@ -192,10 +194,11 @@ int main(int argc, char ** argv){
// KV cache management
// clean the cache of draft tokens that weren't accepted
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
llama_kv_self_seq_rm(ctx, 0, n_past, -1);
common_batch_clear(batch_tgt);
common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
const llama_seq_id seq_id = 0;
llama_batch_ext_clear(batch_tgt);
llama_batch_ext_add_text(batch_tgt, draft[0], n_past, &seq_id, 1, true);
// Draft already contains a single token sampled from the model:
GGML_ASSERT(draft.size() == 1);
@@ -205,13 +208,13 @@ int main(int argc, char ** argv){
common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
for (size_t i = 1; i < draft.size(); ++i) {
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true);
}
t_draft_us += ggml_time_us() - t_start_draft_us;
n_drafted += draft.size() - 1;
llama_decode(ctx, batch_tgt);
llama_decode_ext(ctx, batch_tgt);
++n_past;
draft.erase(draft.begin());
@@ -243,7 +246,7 @@ int main(int argc, char ** argv){
common_sampler_free(smpl);
llama_batch_free(batch_tgt);
llama_batch_ext_free(batch_tgt);
llama_backend_free();

View File

@@ -27,12 +27,24 @@ Once downloaded, place your model in the models folder in llama.cpp.
##### Input prompt (One-and-done)
```bash
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --prompt "Once upon a time"
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf -no-cnv --prompt "Once upon a time"
```
##### Conversation mode (Allow for continuous interaction with the model)
```bash
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf -cnv --chat-template gemma
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --chat-template gemma
```
##### Conversation mode using built-in jinja chat template
```bash
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --jinja
```
##### One-and-done query using jinja with custom system prompt and a starting prompt
```bash
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --jinja --single-turn -sys "You are a helpful assistant" -p "Hello"
```
##### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it):
@@ -44,12 +56,24 @@ Once downloaded, place your model in the models folder in llama.cpp.
##### Input prompt (One-and-done)
```powershell
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --prompt "Once upon a time"
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf -no-cnv --prompt "Once upon a time"
```
##### Conversation mode (Allow for continuous interaction with the model)
```powershell
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf -cnv --chat-template gemma
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --chat-template gemma
```
##### Conversation mode using built-in jinja chat template
```powershell
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --jinja
```
##### One-and-done query using jinja with custom system prompt and a starting prompt
```powershell
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --jinja --single-turn -sys "You are a helpful assistant" -p "Hello"
```
#### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it):
@@ -77,6 +101,8 @@ The `llama-cli` program provides several ways to interact with the LLaMA models
- `--prompt PROMPT`: Provide a prompt directly as a command-line option.
- `--file FNAME`: Provide a file containing a prompt or multiple prompts.
- `--system-prompt PROMPT`: Provide a system prompt (will otherwise use the default one in the chat template (if provided)).
- `--system-prompt-file FNAME`: Provide a file containing a system prompt.
- `--interactive-first`: Run the program in interactive mode and wait for input right away. (More on this below.)
## Interaction
@@ -89,7 +115,10 @@ In interactive mode, users can participate in text generation by injecting their
- `-i, --interactive`: Run the program in interactive mode, allowing users to engage in real-time conversations or provide specific instructions to the model.
- `--interactive-first`: Run the program in interactive mode and immediately wait for user input before starting the text generation.
- `-cnv, --conversation`: Run the program in conversation mode (does not print special tokens and suffix/prefix, use default chat template) (default: false)
- `-cnv, --conversation`: Run the program in conversation mode (does not print special tokens and suffix/prefix, use default or provided chat template) (default: true if chat template found)
- `-no-cnv`: Disable conversation mode (default: false)
- `-st, --single-turn`: Only process a single conversation turn (user input) and then exit.
- `--jinja`: Enable jinja chat template parser, will use the model's built-in template or a user-provided one (default: false)
- `--color`: Enable colorized output to differentiate visually distinguishing between prompts, user input, and generated text.
By understanding and utilizing these interaction options, you can create engaging and dynamic experiences with the LLaMA models, tailoring the text generation process to your specific needs.
@@ -125,6 +154,8 @@ When --in-prefix or --in-suffix options are enabled the chat template ( --chat-t
Example usage: `--chat-template gemma`
`--chat-template-file FNAME`: Load a custom jinja chat template from an external file, useful if the model contains outdated or incompatible template, some examples can be found in models/templates. Up-to-date chat templates can be downloaded from Hugging Face using scripts/get_chat_template.py
## Context Management
During text generation, LLaMA models have a limited context size, which means they can only consider a certain number of tokens from the input and generated text. When the context fills up, the model resets internally, potentially losing some information from the beginning of the conversation or instructions. Context management options help maintain continuity and coherence in these situations.

View File

@@ -354,7 +354,7 @@ int main(int argc, char ** argv) {
}
// remove any "future" tokens that we might have inherited from the previous session
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
llama_kv_self_seq_rm(ctx, -1, n_matching_session_tokens, -1);
}
LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
@@ -548,7 +548,8 @@ int main(int argc, char ** argv) {
int enc_input_size = embd_inp.size();
llama_token * enc_input_buf = embd_inp.data();
if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) {
auto batch = llama_batch_ext_ptr::init_from_text(enc_input_buf, enc_input_size, 0, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
@@ -602,8 +603,8 @@ int main(int argc, char ** argv) {
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
llama_kv_self_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_self_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
n_past -= n_discard;
@@ -626,9 +627,9 @@ int main(int argc, char ** argv) {
LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
llama_kv_self_seq_add(ctx, 0, ga_i, n_past, ib*bd);
llama_kv_self_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
llama_kv_self_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
n_past -= bd;
@@ -668,7 +669,8 @@ int main(int argc, char ** argv) {
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}

View File

@@ -174,7 +174,7 @@ int main(int argc, char ** argv) {
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
llama_batch batch = llama_batch_init(n_ctx, 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1);
int32_t n_total_prompt = 0;
int32_t n_total_gen = 0;
@@ -192,17 +192,18 @@ int main(int argc, char ** argv) {
LOG_INF("%s: Evaluating the system prompt ...\n", __func__);
for (int32_t i = 0; i < n_tokens_system; ++i) {
common_batch_add(batch, tokens_system[i], i, { 0 }, false);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, tokens_system[i], i, &seq_id, 1, false);
}
if (llama_decode(ctx, batch) != 0) {
if (llama_decode_ext(ctx, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
}
LOG_INF("\n");
@@ -216,7 +217,7 @@ int main(int argc, char ** argv) {
common_kv_cache_dump_view_seqs(kvc_view, 40);
}
common_batch_clear(batch);
llama_batch_ext_clear(batch);
// decode any currently ongoing sequences
for (auto & client : clients) {
@@ -224,26 +225,27 @@ int main(int argc, char ** argv) {
continue;
}
client.i_batch = batch.n_tokens;
client.i_batch = llama_batch_ext_get_n_tokens(batch);
common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
llama_seq_id seq_id = client.id + 1;
llama_batch_ext_add_text(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, &seq_id, 1, true);
client.n_decoded += 1;
}
if (batch.n_tokens == 0) {
if (llama_batch_ext_get_n_tokens(batch) == 0) {
// all sequences have ended - clear the entire KV cache
for (int i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, -1, -1);
llama_kv_self_seq_rm(ctx, i, -1, -1);
// but keep the system prompt
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
}
LOG_INF("%s: clearing the KV cache\n", __func__);
}
// insert new sequences for decoding
if (cont_batching || batch.n_tokens == 0) {
if (cont_batching || llama_batch_ext_get_n_tokens(batch) == 0) {
for (auto & client : clients) {
if (client.seq_id == -1 && g_seq_id < n_seq) {
client.seq_id = g_seq_id;
@@ -262,17 +264,18 @@ int main(int argc, char ** argv) {
tokens_prompt = common_tokenize(ctx, client.prompt, false);
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
llama_seq_id seq_id = client.id + 1;
llama_batch_ext_add_text(batch, tokens_prompt[i], i + n_tokens_system, &seq_id, 1, false);
}
// extract the logits only for the last token
if (batch.n_tokens > 0) {
batch.logits[batch.n_tokens - 1] = true;
if (llama_batch_ext_get_n_tokens(batch) > 0) {
llama_batch_ext_set_output_last(batch);
}
client.n_prompt = tokens_prompt.size();
client.n_decoded = 0;
client.i_batch = batch.n_tokens - 1;
client.i_batch = llama_batch_ext_get_n_tokens(batch) - 1;
LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
@@ -286,14 +289,15 @@ int main(int argc, char ** argv) {
}
}
if (batch.n_tokens == 0) {
if (llama_batch_ext_get_n_tokens(batch) == 0) {
break;
}
// process in chunks of params.n_batch
int32_t n_batch = params.n_batch;
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch);
for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) {
// experiment: process in powers of 2
//if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
// n_batch /= 2;
@@ -301,19 +305,11 @@ int main(int argc, char ** argv) {
// continue;
//}
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
const int32_t n_tokens = std::min(n_batch, (int32_t) (n_tokens_in_batch - i));
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
const int ret = llama_decode(ctx, batch_view);
llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_tokens);
const int ret = llama_decode_ext(ctx, batch_view);
llama_batch_ext_free(batch_view);
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size
@@ -372,8 +368,8 @@ int main(int argc, char ** argv) {
}
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
llama_kv_self_seq_rm(ctx, client.id + 1, -1, -1);
llama_kv_self_seq_cp(ctx, 0, client.id + 1, -1, -1);
const auto t_main_end = ggml_time_us();
@@ -417,7 +413,7 @@ int main(int argc, char ** argv) {
// TODO: print sampling/grammar timings for all clients
llama_perf_context_print(ctx);
llama_batch_free(batch);
llama_batch_ext_free(batch);
llama_backend_free();

View File

@@ -2,6 +2,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "llama-cpp.h"
#include <cmath>
#include <cstdio>
@@ -122,7 +123,7 @@ int main(int argc, char ** argv) {
LOG_INF("prompt tokens: %d\n", n_tokens_all);
//LOG_INF("prompt: %s\n", params.prompt.c_str());
llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
llama_batch_ext_ptr batch(llama_batch_ext_init(params.n_batch, 1));
int n_past = 0;
@@ -133,24 +134,25 @@ int main(int argc, char ** argv) {
const int ib = i/n_batch - 1;
const int bd = n_batch_grp*(n_grp - 1);
llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
llama_kv_cache_update (ctx);
llama_kv_self_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
llama_kv_self_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
llama_kv_self_update (ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
}
common_batch_clear(batch);
llama_batch_ext_clear(batch.get());
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
}
if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch.get());
}
if (llama_decode(ctx, batch) != 0) {
if (llama_decode_ext(ctx, batch.get()) != 0) {
LOG_INF("%s: llama_decode() failed\n", __func__);
return 1;
}
@@ -167,24 +169,25 @@ int main(int argc, char ** argv) {
LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard);
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_cache_defrag (ctx);
llama_kv_cache_update (ctx);
llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_self_defrag (ctx);
llama_kv_self_update (ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
common_batch_clear(batch);
llama_batch_ext_clear(batch.get());
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
}
if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch.get());
}
if (llama_decode(ctx, batch) != 0) {
if (llama_decode_ext(ctx, batch.get()) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
@@ -198,12 +201,12 @@ int main(int argc, char ** argv) {
if (n_discard > 0) {
LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_cache_defrag (ctx);
llama_kv_cache_update (ctx);
llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_self_defrag (ctx);
llama_kv_self_update (ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
}
}
@@ -223,7 +226,7 @@ int main(int argc, char ** argv) {
while (n_cur <= n_len) {
// sample the next token
{
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1);
// is it an end of generation?
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
@@ -237,16 +240,17 @@ int main(int argc, char ** argv) {
n_decode += 1;
// prepare the next batch
common_batch_clear(batch);
llama_batch_ext_clear(batch.get());
// push this new token for next evaluation
common_batch_add(batch, new_token_id, n_past++, { 0 }, true);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch.get(), new_token_id, n_past++, &seq_id, 1, true);
}
n_cur += 1;
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
@@ -266,8 +270,6 @@ int main(int argc, char ** argv) {
llama_sampler_free(smpl);
llama_batch_free(batch);
llama_free(ctx);
llama_model_free(model);

View File

@@ -361,23 +361,22 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1);
common_batch batch(n_batch, 1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
common_batch_clear(batch);
batch.clear();
for (int i = 0; i < batch_size; i++) {
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
}
//LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch.get())) {
//LOG_ERR("%s : failed to eval\n", __func__);
llama_batch_free(batch);
return {tokens, -1, logit_history, prob_history};
}
@@ -397,8 +396,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
}
}
llama_batch_free(batch);
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
@@ -504,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
common_batch batch(std::min(n_batch, n_ctx*n_seq), 1);
std::vector<float> logits;
if (num_batches > 1) {
@@ -547,7 +544,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@@ -555,7 +552,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
int n_outputs = 0;
batch.n_tokens = 0;
batch.clear();
for (int seq = 0; seq < n_seq_batch; seq++) {
int seq_start = batch_start + seq*n_ctx;
@@ -568,22 +565,18 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
}
for (int k = 0; k < batch_size; ++k) {
const int idx = seq*n_ctx + k;
batch.token [idx] = tokens[seq_start + k];
batch.pos [idx] = j*n_batch + k;
batch.n_seq_id[idx] = 1;
batch.seq_id [idx][0] = seq;
batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
const llama_pos pos = j*n_batch + k;
bool output = pos >= first;
batch.add_text(tokens[seq_start + k], pos, seq, output);
n_outputs += batch.logits[idx] != 0;
n_outputs += output ? 1 : 0;
}
batch.n_tokens += batch_size;
// restore the original token in case it was set to BOS
tokens[seq_start] = token_org;
}
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch.get())) {
LOG_INF("%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
@@ -653,36 +646,23 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
}
llama_batch_free(batch);
return {tokens, ppl, logit_history, prob_history};
}
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
static bool decode_helper(llama_context * ctx, common_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
int prev_outputs = 0;
for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);
for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) {
const int n_tokens = std::min<int>(n_batch, batch.get_n_tokens() - i);
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
common_batch batch_view = batch.get_view(i, n_tokens);
const int ret = llama_decode(ctx, batch_view);
const int ret = llama_decode_ext(ctx, batch_view.get());
if (ret != 0) {
LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
}
int n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
n_outputs += batch_view.logits[i] != 0;
}
int n_outputs = batch_view.n_outputs;
memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
@@ -863,7 +843,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, 4);
common_batch batch(n_ctx, 4);
std::vector<float> tok_logits(n_vocab);
// TODO: this could be made smaller; it's currently the worst-case size
@@ -879,7 +859,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
size_t i1 = i0;
size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
common_batch_clear(batch);
batch.clear();
// batch as much tasks as possible into the available context
// each task has 4 unique sequence ids - one for each ending
@@ -895,9 +875,9 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
}
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
batch.add_text_multi_seq(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
llama_batch_ext_set_output_last(batch.get());
n_logits += 1;
for (int s = 0; s < 4; ++s) {
@@ -905,7 +885,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
// TODO: don't evaluate the last token of each sequence
for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
const bool needs_logits = i < seq_tokens_size - 1;
common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
batch.add_text_multi_seq(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
n_logits += needs_logits;
}
}
@@ -924,7 +904,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
return;
}
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -992,8 +972,6 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
i0 = i1 - 1;
}
llama_batch_free(batch);
LOG("\n");
}
@@ -1147,7 +1125,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
const int max_tasks_per_batch = 128;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, 2);
common_batch batch(n_ctx, 2);
std::vector<float> tok_logits(n_vocab);
// TODO: this could be made smaller; it's currently the worst-case size
@@ -1166,7 +1144,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
size_t i1 = i0;
size_t i_logits = 0;
common_batch_clear(batch);
batch.clear();
while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
int n_logits = 0;
@@ -1176,15 +1154,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
}
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
batch.add_text_multi_seq(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
}
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch.get());
n_logits += 1;
for (int s = 0; s < 2; ++s) {
// TODO: end before the last token, no need to predict past the end of the sequences
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
batch.add_text_multi_seq(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
n_logits += 1;
}
}
@@ -1203,7 +1181,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
return;
}
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1501,7 +1479,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
common_batch batch(n_ctx, max_seq);
std::vector<float> tok_logits(n_vocab);
std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
@@ -1521,7 +1499,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
size_t i1 = i0;
size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
common_batch_clear(batch);
batch.clear();
// batch as much tasks as possible into the available context
// each task has 4 unique sequence ids - one for each ending
@@ -1544,9 +1522,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
for (size_t i = 0; i < cur_task.common_prefix; ++i) {
//llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
batch.add_text_multi_seq(cur_task.seq_tokens[0][i], i, batch_indeces, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
n_logits += 1;
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
@@ -1554,7 +1532,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
// TODO: don't evaluate the last token of each sequence
for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
const bool needs_logits = i < seq_tokens_size - 1;
common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
batch.add_text_multi_seq(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
n_logits += needs_logits;
}
}
@@ -1575,7 +1553,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
return;
}
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1653,8 +1631,6 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
i0 = i1 - 1;
}
llama_batch_free(batch);
if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return;
float p = 1.f*n_correct/n_done;
@@ -1765,9 +1741,9 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
}
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1);
common_batch batch(n_batch, 1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@@ -1781,14 +1757,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
tokens[batch_start] = llama_vocab_bos(vocab);
}
common_batch_clear(batch);
batch.clear();
for (int i = 0; i < batch_size; i++) {
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
batch.add_text_multi_seq(tokens[batch_start + i], j*n_batch + i, {0}, true);
}
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
llama_batch_free(batch);
return;
}
@@ -1801,8 +1776,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
}
}
llama_batch_free(batch);
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {

View File

@@ -1,6 +1,6 @@
#include "ggml.h"
#include "llama.h"
#include "llama-context.h"
#include "llama-model.h"
#include "common.h"
#include <algorithm>
@@ -328,7 +328,7 @@ int main(int argc, char ** argv) {
}
}
const auto & tensors = llama_internal_get_tensor_map(ctx);
const auto & tensors = llama_internal_get_tensor_map(model);
// check layer tensors
int included_layers = 0;

View File

@@ -74,40 +74,56 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
return chunks;
}
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
size_t n_tokens = tokens.size();
for (size_t i = 0; i < n_tokens; i++) {
common_batch_add(batch, tokens[i], i, { seq_id }, true);
batch.add_text(tokens[i], i, seq_id, true);
}
}
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
const struct llama_model * model = llama_get_model(ctx);
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
if (llama_decode(ctx, batch) < 0) {
LOG_ERR("%s : failed to decode\n", __func__);
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
// encoder-only model
if (llama_encode_ext(ctx, batch.get()) < 0) {
LOG_ERR("%s : failed to encode\n", __func__);
}
} else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
// decoder-only model
if (llama_decode_ext(ctx, batch.get()) < 0) {
LOG_ERR("%s : failed to decode\n", __func__);
}
}
for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
if (!batch.tokens[i].logits) {
continue;
}
// try to get sequence embeddings - supported only when pooling_type is not NONE
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
const float * embd = nullptr;
int embd_pos = 0;
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
// try to get token embeddings
embd = llama_get_embeddings_ith(ctx, i);
if (embd == NULL) {
LOG_ERR("%s: failed to get embeddings for token %d\n", __func__, i);
continue;
}
embd_pos = i;
GGML_ASSERT(embd != NULL && "failed to get token embeddings");
} else {
// try to get sequence embeddings - supported only when pooling_type is not NONE
embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
embd_pos = batch.tokens[i].seq_id;
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
}
float * out = output + batch.seq_id[i][0] * n_embd;
common_embd_normalize(embd, out, n_embd, 2);
float * out = output + embd_pos * n_embd;
common_embd_normalize(embd, out, n_embd, embd_norm);
}
}
@@ -214,7 +230,7 @@ int main(int argc, char ** argv) {
// initialize batch
const int n_chunks = chunks.size();
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
struct common_batch batch = common_batch(n_batch, 1);
// allocate output
const int n_embd = llama_model_n_embd(model);
@@ -231,10 +247,10 @@ int main(int argc, char ** argv) {
const uint64_t n_toks = inp.size();
// encode if at capacity
if (batch.n_tokens + n_toks > n_batch) {
if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) {
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
common_batch_clear(batch);
batch.clear();
p += s;
s = 0;
}
@@ -255,7 +271,7 @@ int main(int argc, char ** argv) {
chunks[i].tokens.clear();
}
struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
struct common_batch query_batch = common_batch(n_batch, 1);
// start loop, receive query and return top k similar chunks based on cosine similarity
std::string query;
@@ -269,7 +285,7 @@ int main(int argc, char ** argv) {
std::vector<float> query_emb(n_embd, 0);
batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
common_batch_clear(query_batch);
query_batch.clear();
// compute cosine similarities
{
@@ -299,6 +315,5 @@ int main(int argc, char ** argv) {
llama_perf_context_print(ctx);
// clean up
llama_batch_free(query_batch);
llama_backend_free();
}

View File

@@ -79,6 +79,7 @@ class Opt {
ctx_params = llama_context_default_params();
model_params = llama_model_default_params();
context_size_default = ctx_params.n_batch;
n_threads_default = ctx_params.n_threads;
ngl_default = model_params.n_gpu_layers;
common_params_sampling sampling;
temperature_default = sampling.temp;
@@ -104,6 +105,7 @@ class Opt {
ctx_params.n_batch = context_size >= 0 ? context_size : context_size_default;
ctx_params.n_ctx = ctx_params.n_batch;
ctx_params.n_threads = ctx_params.n_threads_batch = n_threads >= 0 ? n_threads : n_threads_default;
model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default;
temperature = temperature >= 0 ? temperature : temperature_default;
@@ -116,12 +118,12 @@ class Opt {
std::string chat_template_file;
std::string user;
bool use_jinja = false;
int context_size = -1, ngl = -1;
int context_size = -1, ngl = -1, n_threads = -1;
float temperature = -1;
bool verbose = false;
private:
int context_size_default = -1, ngl_default = -1;
int context_size_default = -1, ngl_default = -1, n_threads_default = -1;
float temperature_default = -1;
bool help = false;
@@ -159,53 +161,94 @@ class Opt {
return 0;
}
int parse_options_with_value(int argc, const char ** argv, int & i, bool & options_parsing) {
if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
if (handle_option_with_value(argc, argv, i, context_size) == 1) {
return 1;
}
} else if (options_parsing &&
(strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "-ngl") == 0 || strcmp(argv[i], "--ngl") == 0)) {
if (handle_option_with_value(argc, argv, i, ngl) == 1) {
return 1;
}
} else if (options_parsing && (strcmp(argv[i], "-t") == 0 || strcmp(argv[i], "--threads") == 0)) {
if (handle_option_with_value(argc, argv, i, n_threads) == 1) {
return 1;
}
} else if (options_parsing && strcmp(argv[i], "--temp") == 0) {
if (handle_option_with_value(argc, argv, i, temperature) == 1) {
return 1;
}
} else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0) {
if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
return 1;
}
use_jinja = true;
} else {
return 2;
}
return 0;
}
int parse_options(const char ** argv, int & i, bool & options_parsing) {
if (options_parsing && (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
verbose = true;
} else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
use_jinja = true;
} else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
help = true;
return 0;
} else if (options_parsing && strcmp(argv[i], "--") == 0) {
options_parsing = false;
} else {
return 2;
}
return 0;
}
int parse_positional_args(const char ** argv, int & i, int & positional_args_i) {
if (positional_args_i == 0) {
if (!argv[i][0] || argv[i][0] == '-') {
return 1;
}
++positional_args_i;
model_ = argv[i];
} else if (positional_args_i == 1) {
++positional_args_i;
user = argv[i];
} else {
user += " " + std::string(argv[i]);
}
return 0;
}
int parse(int argc, const char ** argv) {
bool options_parsing = true;
for (int i = 1, positional_args_i = 0; i < argc; ++i) {
if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
if (handle_option_with_value(argc, argv, i, context_size) == 1) {
return 1;
}
} else if (options_parsing &&
(strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "-ngl") == 0 || strcmp(argv[i], "--ngl") == 0)) {
if (handle_option_with_value(argc, argv, i, ngl) == 1) {
return 1;
}
} else if (options_parsing && strcmp(argv[i], "--temp") == 0) {
if (handle_option_with_value(argc, argv, i, temperature) == 1) {
return 1;
}
} else if (options_parsing &&
(parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
verbose = true;
} else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
use_jinja = true;
} else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0){
if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
return 1;
}
use_jinja = true;
} else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
help = true;
return 0;
} else if (options_parsing && strcmp(argv[i], "--") == 0) {
options_parsing = false;
} else if (positional_args_i == 0) {
if (!argv[i][0] || argv[i][0] == '-') {
return 1;
}
int ret = parse_options_with_value(argc, argv, i, options_parsing);
if (ret == 0) {
continue;
} else if (ret == 1) {
return ret;
}
++positional_args_i;
model_ = argv[i];
} else if (positional_args_i == 1) {
++positional_args_i;
user = argv[i];
} else {
user += " " + std::string(argv[i]);
ret = parse_options(argv, i, options_parsing);
if (ret == 0) {
continue;
} else if (ret == 1) {
return ret;
}
if (parse_positional_args(argv, i, positional_args_i)) {
return 1;
}
}
if (model_.empty()){
if (model_.empty()) {
return 1;
}
@@ -232,6 +275,8 @@ class Opt {
" Number of GPU layers (default: %d)\n"
" --temp <value>\n"
" Temperature (default: %.1f)\n"
" -t, --threads <value>\n"
" Number of threads to use during generation (default: %d)\n"
" -v, --verbose, --log-verbose\n"
" Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n"
" -h, --help\n"
@@ -260,7 +305,7 @@ class Opt {
" llama-run file://some-file3.gguf\n"
" llama-run --ngl 999 some-file4.gguf\n"
" llama-run --ngl 999 some-file5.gguf Hello World\n",
context_size_default, ngl_default, temperature_default);
context_size_default, ngl_default, temperature_default, n_threads_default);
}
};
@@ -595,6 +640,7 @@ class LlamaData {
std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
std::list<std::string> msg_strs;
std::vector<char> fmtted;
llama_pos n_past = 0;
int init(Opt & opt) {
model = initialize_model(opt);
@@ -891,7 +937,7 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
// Function to tokenize the prompt
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;
const bool is_first = llama_kv_self_used_cells(llama_data.context.get()) == 0;
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
prompt_tokens.resize(n_prompt_tokens);
@@ -905,10 +951,10 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
}
// Check if we have enough space in the context to evaluate this batch
static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
static int check_context_size(const llama_context_ptr & ctx, const llama_batch_ext_ptr & batch) {
const int n_ctx = llama_n_ctx(ctx.get());
const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get());
if (n_ctx_used + batch.n_tokens > n_ctx) {
const int n_ctx_used = llama_kv_self_used_cells(ctx.get());
if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) {
printf(LOG_COL_DEFAULT "\n");
printe("context size exceeded\n");
return 1;
@@ -946,15 +992,17 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
}
// prepare a batch for the prompt
llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size());
auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), llama_data.n_past, 0, true);
llama_token new_token_id;
while (true) {
check_context_size(llama_data.context, batch);
if (llama_decode(llama_data.context.get(), batch)) {
if (llama_decode_ext(llama_data.context.get(), batch.get())) {
printe("failed to decode\n");
return 1;
}
llama_data.n_past += llama_batch_ext_get_n_tokens(batch.get());
// sample the next token, check is it an end of generation?
new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1);
if (llama_vocab_is_eog(vocab, new_token_id)) {
@@ -969,7 +1017,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
print_word_and_concatenate_to_response(piece, response);
// prepare the next batch with the sampled token
batch = llama_batch_get_one(&new_token_id, 1);
batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, llama_data.n_past, 0, true));
}
printf(LOG_COL_DEFAULT);

View File

@@ -15,7 +15,7 @@ int main(int argc, char ** argv) {
return 1;
}
print_build_info();
common_init();
if (params.n_predict < 0) {
params.n_predict = 16;
@@ -48,15 +48,11 @@ int main(int argc, char ** argv) {
auto tokens = common_tokenize(ctx, params.prompt, true);
// prepare the batch
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
for (size_t i = 0; i < tokens.size(); i++) {
common_batch_add(batch, tokens[i], i, {0}, false);
}
batch.logits[batch.n_tokens - 1] = true; // generate next token
llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true);
// evaluate prompt
llama_decode(ctx, batch);
n_past += batch.n_tokens;
llama_decode_ext(ctx, batch);
n_past += llama_batch_ext_get_n_tokens(batch);
// save state (rng, logits, embedding and kv_cache) to file
{
@@ -83,12 +79,13 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str());
result0 += next_token_str;
common_batch_clear(batch);
common_batch_add(batch, next_token, n_past, {0}, true);
llama_batch_ext_clear(batch);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch);
llama_batch_ext_free(batch);
return 1;
}
n_past += 1;
@@ -135,12 +132,13 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str());
result1 += next_token_str;
common_batch_clear(batch);
common_batch_add(batch, next_token, n_past, {0}, true);
llama_batch_ext_clear(batch);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
if (llama_decode(ctx2, batch)) {
if (llama_decode_ext(ctx2, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch);
llama_batch_ext_free(batch);
return 1;
}
n_past += 1;
@@ -196,7 +194,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
// erase whole kv
llama_kv_cache_clear(ctx3);
llama_kv_self_clear(ctx3);
fprintf(stderr, "%s : kv cache cleared\n", __func__);
// restore kv into seq 1
@@ -216,12 +214,13 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str());
result2 += next_token_str;
common_batch_clear(batch);
common_batch_add(batch, next_token, n_past, {1}, true);
llama_batch_ext_clear(batch);
llama_seq_id seq_id = 1; // seq 1 instead of 0
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
if (llama_decode(ctx3, batch)) {
if (llama_decode_ext(ctx3, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch);
llama_batch_ext_free(batch);
return 1;
}
n_past += 1;
@@ -233,7 +232,7 @@ int main(int argc, char ** argv) {
llama_sampler_free(smpl2);
llama_sampler_free(smpl3);
llama_batch_free(batch);
llama_batch_ext_free(batch);
if (result0 != result2) {
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);

View File

@@ -1224,7 +1224,7 @@ struct server_slot {
// only used for completion/embedding/infill/rerank
server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
llama_batch batch_spec = {};
common_batch batch_spec;
llama_context * ctx = nullptr;
llama_context * ctx_dft = nullptr;
@@ -1796,7 +1796,7 @@ struct server_context {
llama_context_params cparams_dft;
llama_batch batch = {};
common_batch batch;
bool clean_kv_cache = true;
bool add_bos_token = true;
@@ -1829,11 +1829,7 @@ struct server_context {
common_speculative_free(slot.spec);
slot.spec = nullptr;
llama_batch_free(slot.batch_spec);
}
llama_batch_free(batch);
}
bool load_model(const common_params & params) {
@@ -1872,6 +1868,10 @@ struct server_context {
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.n_parallel = 1;
// force F16 KV cache for the draft model for extra performance
params_dft.cache_type_k = GGML_TYPE_F16;
params_dft.cache_type_v = GGML_TYPE_F16;
llama_init_dft = common_init_from_params(params_dft);
model_dft = llama_init_dft.model.get();
@@ -1892,10 +1892,6 @@ struct server_context {
cparams_dft = common_context_params_to_llama(params_dft);
cparams_dft.n_batch = n_ctx_dft;
// force F16 KV cache for the draft model for extra performance
cparams_dft.type_k = GGML_TYPE_F16;
cparams_dft.type_v = GGML_TYPE_F16;
// the context is not needed - we will create one for each slot
llama_init_dft.context.reset();
}
@@ -1926,7 +1922,7 @@ struct server_context {
slot.n_predict = params_base.n_predict;
if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
slot.batch_spec = common_batch(params_base.speculative.n_max + 1, 1);
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) {
@@ -1951,7 +1947,7 @@ struct server_context {
slot.reset();
slots.push_back(slot);
slots.push_back(std::move(slot));
}
default_generation_settings_for_props = slots[0].to_json();
@@ -1962,7 +1958,7 @@ struct server_context {
const int32_t n_batch = llama_n_batch(ctx);
// only a single seq_id per token is needed
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
batch = common_batch(std::max(n_batch, params_base.n_parallel), 1);
}
metrics.init();
@@ -2040,6 +2036,18 @@ struct server_context {
return ret;
}
bool can_be_detokenized(const struct llama_context * ctx, const std::vector<llama_token> & tokens) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
for (const auto & token : tokens) {
if (token < 0 || token >= n_vocab) {
return false;
}
}
return true;
}
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
slot.reset();
slot.id_task = task.id;
@@ -2054,6 +2062,11 @@ struct server_context {
slot.lora = task.params.lora;
}
bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens);
if (!can_detokenize) {
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
return false;
}
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -2080,9 +2093,7 @@ struct server_context {
}
if (slot.ctx_dft) {
llama_batch_free(slot.batch_spec);
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
slot.batch_spec = common_batch(slot.params.speculative.n_max + 1, 1);
}
slot.state = SLOT_STATE_STARTED;
@@ -2096,7 +2107,7 @@ struct server_context {
SRV_DBG("%s", "clearing KV cache\n");
// clear the entire KV cache
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
clean_kv_cache = false;
}
@@ -2390,7 +2401,7 @@ struct server_context {
queue_results.send(std::move(res));
}
void send_embedding(const server_slot & slot, const llama_batch & batch) {
void send_embedding(const server_slot & slot, common_batch & batch) {
auto res = std::make_unique<server_task_result_embd>();
res->id = slot.id_task;
res->index = slot.index;
@@ -2401,18 +2412,19 @@ struct server_context {
std::vector<float> embd_res(n_embd, 0.0f);
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
for (int i = 0; i < batch.get_n_tokens(); ++i) {
auto tok = batch.tokens[i];
if (!tok.logits || tok.seq_id != slot.id) {
continue;
}
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
}
if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
continue;
@@ -2433,24 +2445,25 @@ struct server_context {
queue_results.send(std::move(res));
}
void send_rerank(const server_slot & slot, const llama_batch & batch) {
void send_rerank(const server_slot & slot, common_batch & batch) {
auto res = std::make_unique<server_task_result_rerank>();
res->id = slot.id_task;
res->index = slot.index;
res->n_tokens = slot.n_prompt_tokens;
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
for (int i = 0; i < batch.get_n_tokens(); ++i) {
auto tok = batch.tokens[i];
if (!tok.logits || tok.seq_id != slot.id) {
continue;
}
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
}
if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
res->score = -1e6;
continue;
@@ -2638,8 +2651,8 @@ struct server_context {
res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
res->t_start = metrics.t_start;
res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx);
res->kv_cache_tokens_count = llama_kv_self_n_tokens(ctx);
res->kv_cache_used_cells = llama_kv_self_used_cells(ctx);
res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
res->t_prompt_processing_total = metrics.t_prompt_processing_total;
@@ -2755,7 +2768,7 @@ struct server_context {
// Erase token cache
const size_t n_erased = slot->cache_tokens.size();
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
llama_kv_self_seq_rm(ctx, slot->id, -1, -1);
slot->cache_tokens.clear();
auto res = std::make_unique<server_task_result_slot_erase>();
@@ -2823,8 +2836,8 @@ struct server_context {
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@@ -2841,7 +2854,7 @@ struct server_context {
}
// start populating the batch for this iteration
common_batch_clear(batch);
batch.clear();
// track if given slot can be batched with slots already in the batch
server_slot * slot_batched = nullptr;
@@ -2863,9 +2876,9 @@ struct server_context {
continue;
}
slot.i_batch = batch.n_tokens;
slot.i_batch = batch.get_n_tokens();
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
batch.add_text(slot.sampled, slot.n_past, slot.id, true);
slot.n_past += 1;
@@ -2882,7 +2895,7 @@ struct server_context {
int32_t n_ubatch = llama_n_ubatch(ctx);
// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
if (params_base.cont_batching || batch.get_n_tokens() == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
if (slot.is_processing()) {
@@ -3015,8 +3028,8 @@ struct server_context {
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
llama_kv_cache_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
llama_kv_self_seq_rm (ctx, slot.id, head_p, head_c);
llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
for (size_t i = 0; i < n_match; i++) {
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
@@ -3048,15 +3061,15 @@ struct server_context {
// non-causal tasks require to fit the entire prompt in the physical batch
if (slot.is_non_causal()) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
if (batch.get_n_tokens() + slot.n_prompt_tokens > n_batch) {
continue;
}
}
// keep only the common part
if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
llama_kv_self_seq_rm(ctx, slot.id, -1, -1);
// there is no common part left
slot.n_past = 0;
@@ -3068,11 +3081,11 @@ struct server_context {
slot.cache_tokens.resize(slot.n_past);
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
while (slot.n_past < slot.n_prompt_tokens && batch.get_n_tokens() < n_batch) {
// without pooling, we want to output the embeddings for all the tokens in the batch
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd);
if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3082,13 +3095,13 @@ struct server_context {
slot.n_past++;
}
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.get_n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
// entire prompt has been processed
if (slot.n_past == slot.n_prompt_tokens) {
slot.state = SLOT_STATE_DONE_PROMPT;
GGML_ASSERT(batch.n_tokens > 0);
GGML_ASSERT(batch.get_n_tokens() > 0);
common_sampler_reset(slot.smpl);
@@ -3098,27 +3111,27 @@ struct server_context {
}
// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;
batch.set_logits_last();
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
slot.i_batch = batch.get_n_tokens() - 1;
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.get_n_tokens());
}
}
if (batch.n_tokens >= n_batch) {
if (batch.get_n_tokens() >= n_batch) {
break;
}
}
}
if (batch.n_tokens == 0) {
if (batch.get_n_tokens() == 0) {
SRV_WRN("%s", "no tokens to decode\n");
return;
}
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
SRV_DBG("decoding batch, n_tokens = %d\n", batch.get_n_tokens());
if (slot_batched) {
// make sure we're in the right embedding mode
@@ -3128,20 +3141,12 @@ struct server_context {
}
// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
common_batch batch_view = batch.get_view(i, n_tokens);
const int ret = llama_decode(ctx, batch_view);
const int ret = llama_decode_ext(ctx, batch_view.get());
metrics.on_decoded(slots);
if (ret != 0) {
@@ -3276,16 +3281,16 @@ struct server_context {
}
// construct the speculation batch
common_batch_clear(slot.batch_spec);
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
slot.batch_spec.clear();
slot.batch_spec.add_text(id, slot.n_past, slot.id, true);
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true);
}
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
llama_decode(ctx, slot.batch_spec);
llama_decode_ext(ctx, slot.batch_spec.get());
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
@@ -3296,7 +3301,7 @@ struct server_context {
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;

View File

@@ -302,7 +302,7 @@ class ServerPreset:
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "tinyllamas/stories260K.gguf"
server.model_alias = "tinyllama-2"
server.n_ctx = 256
server.n_ctx = 512
server.n_batch = 32
server.n_slots = 2
server.n_predict = 64

View File

@@ -621,7 +621,9 @@ static json oaicompat_completion_params_parse(
llama_params["chat_format"] = static_cast<int>(chat_params.format);
llama_params["prompt"] = chat_params.prompt;
llama_params["grammar"] = chat_params.grammar;
if (!chat_params.grammar.empty()) {
llama_params["grammar"] = chat_params.grammar;
}
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
auto grammar_triggers = json::array();
for (const auto & trigger : chat_params.grammar_triggers) {

View File

@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
auto generate = [&](const std::string & prompt) {
std::string response;
const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0;
const bool is_first = llama_kv_self_used_cells(ctx) == 0;
// tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
@@ -108,19 +108,22 @@ int main(int argc, char ** argv) {
}
// prepare a batch for the prompt
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
llama_pos n_past = 0;
llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true);
n_past += llama_batch_ext_get_n_tokens(batch);
llama_token new_token_id;
while (true) {
// check if we have enough space in the context to evaluate this batch
int n_ctx = llama_n_ctx(ctx);
int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
if (n_ctx_used + batch.n_tokens > n_ctx) {
int n_ctx_used = llama_kv_self_used_cells(ctx);
if (n_ctx_used + llama_batch_ext_get_n_tokens(batch) > n_ctx) {
printf("\033[0m\n");
fprintf(stderr, "context size exceeded\n");
exit(0);
}
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
GGML_ABORT("failed to decode\n");
}
@@ -144,9 +147,14 @@ int main(int argc, char ** argv) {
response += piece;
// prepare the next batch with the sampled token
batch = llama_batch_get_one(&new_token_id, 1);
llama_batch_ext_clear(batch);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, new_token_id, n_past, &seq_id, 1, true);
n_past++;
}
llama_batch_ext_free(batch);
return response;
};

View File

@@ -143,7 +143,7 @@ int main(int argc, char ** argv) {
// prepare a batch for the prompt
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0, true);
// main loop
@@ -151,14 +151,14 @@ int main(int argc, char ** argv) {
int n_decode = 0;
llama_token new_token_id;
for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) {
for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch) < n_prompt + n_predict; ) {
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
n_pos += batch.n_tokens;
n_pos += llama_batch_ext_get_n_tokens(batch);
// sample the next token
{
@@ -180,7 +180,9 @@ int main(int argc, char ** argv) {
fflush(stdout);
// prepare the next batch with the sampled token
batch = llama_batch_get_one(&new_token_id, 1);
llama_batch_ext_clear(batch);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, new_token_id, n_pos, &seq_id, 1, true);
n_decode += 1;
}
@@ -198,6 +200,7 @@ int main(int argc, char ** argv) {
llama_perf_context_print(ctx);
fprintf(stderr, "\n");
llama_batch_ext_free(batch);
llama_sampler_free(smpl);
llama_free(ctx);
llama_model_free(model);

View File

@@ -113,7 +113,8 @@ int main(int argc, char ** argv) {
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
// eval the prompt
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
auto batch = llama_batch_ext_ptr::init_from_text(inp.data(), inp.size() - 1, 0, 0, true);
llama_decode_ext(ctx_tgt, batch.get());
// note: keep the last token separate!
llama_token id_last = inp.back();
@@ -132,7 +133,7 @@ int main(int argc, char ** argv) {
struct common_speculative * spec = common_speculative_init(ctx_dft);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), 1);
const auto t_enc_end = ggml_time_us();
@@ -151,8 +152,9 @@ int main(int argc, char ** argv) {
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
// always have a token to evaluate from before - id_last
common_batch_clear(batch_tgt);
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
llama_batch_ext_clear(batch_tgt);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch_tgt, id_last, n_past++, &seq_id, 1, true);
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
{
@@ -162,12 +164,12 @@ int main(int argc, char ** argv) {
}
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true);
}
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
llama_decode_ext(ctx_tgt, batch_tgt);
}
// sample from the full target batch and return the accepted tokens based on the target sampler
@@ -217,7 +219,7 @@ int main(int argc, char ** argv) {
{
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
llama_kv_self_seq_rm(ctx_tgt, 0, n_past, -1);
}
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
@@ -253,6 +255,7 @@ int main(int argc, char ** argv) {
common_sampler_free(smpl);
common_speculative_free(spec);
llama_batch_ext_free(batch_tgt);
llama_backend_free();
LOG("\n\n");

View File

@@ -45,7 +45,6 @@ int main(int argc, char ** argv) {
}
common_init();
if (params.speculative.model.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
@@ -166,9 +165,12 @@ int main(int argc, char ** argv) {
const auto t_enc_start = ggml_time_us();
// eval the prompt with both models
llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1));
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true));
llama_decode_ext(ctx_tgt, batch0.get());
llama_decode_ext(ctx_tgt, batch1.get());
llama_decode_ext(ctx_dft, batch2.get());
const auto t_enc_end = ggml_time_us();
@@ -199,8 +201,8 @@ int main(int argc, char ** argv) {
drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
}
llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);
llama_batch_ext * batch_dft = llama_batch_ext_init(llama_n_batch(ctx_dft), 1);
llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft);
const auto t_dec_start = ggml_time_us();
@@ -335,7 +337,7 @@ int main(int argc, char ** argv) {
if (i == s) {
continue;
}
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
if (drafts[i].active && drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
// synchronize active status for sequences with the same drafted token
drafts[i].active = drafts[i].active && accept;
if (!drafts[i].active) {
@@ -420,14 +422,14 @@ int main(int argc, char ** argv) {
{
LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
llama_kv_cache_seq_keep(ctx_dft, s_keep);
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_dft, 0);
llama_kv_self_seq_keep(ctx_dft, s_keep);
llama_kv_self_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_self_seq_keep(ctx_dft, 0);
llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_tgt, 0);
llama_kv_self_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_kv_self_seq_keep(ctx_tgt, s_keep);
llama_kv_self_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_self_seq_keep(ctx_tgt, 0);
}
for (int s = 0; s < n_seq_dft; ++s) {
@@ -441,12 +443,13 @@ int main(int argc, char ** argv) {
drafts[0].dists.push_back(std::vector<llama_token_data>());
drafts[0].i_batch_tgt.push_back(0);
common_batch_clear(batch_dft);
common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_batch_ext_clear(batch_dft);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch_dft, token_id, n_past_dft, &seq_id, 1, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode(ctx_dft, batch_dft);
llama_decode_ext(ctx_dft, batch_dft);
++n_past_dft;
}
@@ -471,12 +474,19 @@ int main(int argc, char ** argv) {
drafts[0].drafting = true;
drafts[0].i_batch_dft = 0;
common_batch_clear(batch_tgt);
common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
struct batch_info {
llama_token id;
llama_pos pos;
std::vector<llama_seq_id> seq_id;
};
std::vector<batch_info> batch_tgt_data;
batch_tgt_data.push_back({ drafts[0].tokens[0], n_past_tgt, {0} });
// sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) {
batch_dft.n_tokens = 0;
llama_batch_ext_clear(batch_dft);
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].skip = false;
@@ -503,15 +513,14 @@ int main(int argc, char ** argv) {
if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) {
LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
llama_kv_self_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
// all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
if (batch_tgt.seq_id[t][p] == s) {
batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
batch_tgt.n_seq_id[t]++;
for (int t = 0; t < (int) batch_tgt_data.size(); ++t) {
for (int p = 0; p < (int) batch_tgt_data[t].seq_id.size(); ++p) {
if (batch_tgt_data[t].seq_id[p] == s) {
batch_tgt_data[t].seq_id.push_back(n_seq_cur);
break;
}
}
@@ -553,45 +562,50 @@ int main(int argc, char ** argv) {
drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
// add unique drafted tokens to the target batch
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
drafts[s].i_batch_tgt.push_back(batch_tgt_data.size());
common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
batch_tgt_data.push_back({ id, n_past_tgt + i + 1, { s }});
// add the token to the batch for batched decoding with the draft model
drafts[s].i_batch_dft = batch_dft.n_tokens;
drafts[s].i_batch_dft = llama_batch_ext_add_text(batch_dft, id, n_past_cur, &s, 1, true);
common_batch_add(batch_dft, id, n_past_cur, { s }, true);
if (batch_tgt.n_tokens > n_draft) {
if (batch_tgt_data.size() > (size_t) n_draft) {
drafts[s].drafting = false;
}
}
}
// no sequence is drafting anymore
if (batch_dft.n_tokens == 0) {
if (llama_batch_ext_get_n_tokens(batch_dft) == 0) {
break;
}
// evaluate the drafted tokens on the draft model
llama_decode(ctx_dft, batch_dft);
llama_decode_ext(ctx_dft, batch_dft);
++n_past_cur;
++n_drafted;
if (batch_tgt.n_tokens > n_draft) {
if (batch_tgt_data.size() > (size_t) n_draft) {
break;
}
}
// evaluate the target model on the drafted tokens
{
llama_kv_cache_seq_keep(ctx_tgt, 0);
llama_kv_self_seq_keep(ctx_tgt, 0);
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1);
}
llama_batch_ext_clear(batch_tgt);
for (int i = 0; i < (int) batch_tgt_data.size(); ++i) {
const auto & data = batch_tgt_data[i];
llama_batch_ext_add_text(batch_tgt, data.id, data.pos, data.seq_id.data(), data.seq_id.size(), true);
}
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
llama_decode_ext(ctx_tgt, batch_tgt);
++n_past_tgt;
}
@@ -634,7 +648,8 @@ int main(int argc, char ** argv) {
common_sampler_free(drafts[s].smpl);
}
llama_batch_free(batch_dft);
llama_batch_ext_free(batch_dft);
llama_batch_ext_free(batch_tgt);
llama_backend_free();

View File

@@ -87,11 +87,11 @@ struct wav_header {
uint32_t data_size;
};
static void save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate) {
static bool save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate) {
std::ofstream file(fname, std::ios::binary);
if (!file) {
LOG_ERR("%s: Failed to open file '%s' for writing", __func__, fname.c_str());
return;
LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str());
return false;
}
wav_header header;
@@ -108,7 +108,7 @@ static void save_wav16(const std::string & fname, const std::vector<float> & dat
file.write(reinterpret_cast<const char*>(&pcm_sample), sizeof(pcm_sample));
}
file.close();
return file.good();
}
static void fill_hann_window(int length, bool periodic, float * output) {
@@ -536,6 +536,7 @@ static std::string audio_data_from_speaker(json speaker, const outetts_version t
int main(int argc, char ** argv) {
common_params params;
params.out_file = "output.wav";
params.prompt = "";
params.n_predict = 4096;
@@ -817,7 +818,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
// create a llama_batch
// we use this object to submit token data for decoding
llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel);
llama_batch_ext * batch = llama_batch_ext_init(std::max(prompt_inp.size(), (size_t) n_parallel), n_parallel);
std::vector<llama_seq_id> seq_ids(n_parallel, 0);
for (int32_t i = 0; i < n_parallel; ++i) {
@@ -826,14 +827,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
// evaluate the initial prompt
for (size_t i = 0; i < prompt_inp.size(); ++i) {
common_batch_add(batch, prompt_inp[i], i, seq_ids, false);
llama_batch_ext_add_text(batch, prompt_inp[i], i, seq_ids.data(), seq_ids.size(), false);
}
GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) prompt_inp.size());
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch);
if (llama_decode(ctx_ttc, batch) != 0) {
if (llama_decode_ext(ctx_ttc, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
@@ -852,16 +853,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
// remember the batch index of the last token for each parallel sequence
// we need this to determine which logits to sample from
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1);
int n_past = batch.n_tokens;
int n_past = llama_batch_ext_get_n_tokens(batch);
int n_decode = 0;
bool next_token_uses_guide_token = true;
while (n_decode <= n_predict) {
// prepare the next batch
common_batch_clear(batch);
llama_batch_ext_clear(batch);
// sample the next token for each parallel sequence / stream
for (int32_t i = 0; i < n_parallel; ++i) {
@@ -917,14 +918,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
//LOG_CNT("%d", i);
}
i_batch[i] = batch.n_tokens;
i_batch[i] = llama_batch_ext_get_n_tokens(batch);
// push this new token for next evaluation
common_batch_add(batch, new_token_id, n_past, { i }, true);
llama_batch_ext_add_text(batch, new_token_id, n_past, &i, 1, true);
}
// all streams are finished
if (batch.n_tokens == 0) {
if (llama_batch_ext_get_n_tokens(batch) == 0) {
break;
}
@@ -932,13 +933,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
n_past += 1;
// evaluate the current batch with the transformer model
if (llama_decode(ctx_ttc, batch)) {
if (llama_decode_ext(ctx_ttc, batch)) {
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
}
llama_batch_free(batch);
llama_batch_ext_free(batch);
LOG("\n");
LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f);
@@ -1007,14 +1008,15 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
const int n_codes = codes.size();
llama_batch batch = llama_batch_init(n_codes, 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(n_codes, 1);
for (size_t i = 0; i < codes.size(); ++i) {
common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits?
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, codes[i], i, &seq_id, 1, true); // TODO: all logits?
}
GGML_ASSERT(batch.n_tokens == n_codes);
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == n_codes);
if (llama_decode(ctx_cts, batch) != 0) {
if (llama_decode_ext(ctx_cts, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
@@ -1060,8 +1062,6 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
}
#endif
const std::string fname = "output.wav";
const int n_sr = 24000; // sampling rate
// zero out first 0.25 seconds
@@ -1072,11 +1072,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
LOG_INF("%s: time for spectral ops: %.3f ms\n", __func__, (ggml_time_us() - t_spec_start) / 1000.0f);
LOG_INF("%s: total time: %.3f ms\n", __func__, (ggml_time_us() - t_main_start) / 1000.0f);
save_wav16(fname, audio, n_sr);
int retval = 0;
LOG_INF("%s: audio written to file '%s'\n", __func__, fname.c_str());
if (save_wav16(params.out_file, audio, n_sr)) {
LOG_INF("%s: audio written to file '%s'\n", __func__, params.out_file.c_str());
} else {
retval = ENOENT;
}
llama_batch_ext_free(batch);
llama_backend_free();
return 0;
return retval;
}

View File

@@ -186,6 +186,7 @@ option(GGML_OPENMP "ggml: use OpenMP"
option(GGML_RPC "ggml: use RPC" OFF)
option(GGML_SYCL "ggml: use SYCL" OFF)
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
"ggml: sycl target device")
set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING

26
ggml/cmake/common.cmake Normal file
View File

@@ -0,0 +1,26 @@
function(ggml_get_flags CCID CCVER)
set(C_FLAGS "")
set(CXX_FLAGS "")
if (CCID MATCHES "Clang")
set(C_FLAGS -Wunreachable-code-break -Wunreachable-code-return)
set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi)
if (
(CCID STREQUAL "Clang" AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR
(CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0)
)
list(APPEND C_FLAGS -Wdouble-promotion)
endif()
elseif (CCID STREQUAL "GNU")
set(C_FLAGS -Wdouble-promotion)
set(CXX_FLAGS -Wno-array-bounds)
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
list(APPEND CXX_FLAGS -Wextra-semi)
endif()
endif()
set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE)
set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)
endfunction()

View File

@@ -454,6 +454,7 @@ extern "C" {
GGML_OP_RMS_NORM,
GGML_OP_RMS_NORM_BACK,
GGML_OP_GROUP_NORM,
GGML_OP_L2_NORM,
GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
@@ -502,6 +503,7 @@ extern "C" {
GGML_OP_ADD_REL_POS,
GGML_OP_RWKV_WKV6,
GGML_OP_GATED_LINEAR_ATTN,
GGML_OP_RWKV_WKV7,
GGML_OP_UNARY,
@@ -1095,6 +1097,18 @@ extern "C" {
int n_groups,
float eps);
// l2 normalize along rows
// used in rwkv v7
GGML_API struct ggml_tensor * ggml_l2_norm(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps);
GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps);
// a - x
// b - dy
GGML_API struct ggml_tensor * ggml_rms_norm_back(
@@ -1890,6 +1904,16 @@ extern "C" {
struct ggml_tensor * state,
float scale);
GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
struct ggml_context * ctx,
struct ggml_tensor * r,
struct ggml_tensor * w,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * state);
// custom operators
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);

View File

@@ -1,4 +1,5 @@
include(CheckCXXCompilerFlag)
include("../cmake/common.cmake")
add_compile_definitions(GGML_SCHED_MAX_COPIES=${GGML_SCHED_MAX_COPIES})
@@ -24,33 +25,6 @@ if (NOT MSVC)
endif()
endif()
function(ggml_get_flags CCID CCVER)
set(C_FLAGS "")
set(CXX_FLAGS "")
if (CCID MATCHES "Clang")
set(C_FLAGS -Wunreachable-code-break -Wunreachable-code-return)
set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi)
if (
(CCID STREQUAL "Clang" AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR
(CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0)
)
list(APPEND C_FLAGS -Wdouble-promotion)
endif()
elseif (CCID STREQUAL "GNU")
set(C_FLAGS -Wdouble-promotion)
set(CXX_FLAGS -Wno-array-bounds)
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
list(APPEND CXX_FLAGS -Wextra-semi)
endif()
endif()
set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE)
set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)
endfunction()
if (GGML_FATAL_WARNINGS)
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
list(APPEND C_FLAGS -Werror)

View File

@@ -497,7 +497,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
search_paths.push_back(get_executable_path());
search_paths.push_back(fs::current_path());
} else {
search_paths.push_back(user_search_path);
search_paths.push_back(fs::u8path(user_search_path));
}
int best_score = 0;
@@ -511,9 +511,9 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) {
if (entry.is_regular_file()) {
auto filename = entry.path().filename().native();
auto ext = entry.path().extension().native();
if (filename.find(file_prefix) == 0 && ext == file_extension) {
auto filename = entry.path().filename();
auto ext = entry.path().extension();
if (filename.native().find(file_prefix) == 0 && ext == file_extension) {
dl_handle_ptr handle { dl_load_library(entry) };
if (!handle && !silent) {
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(entry.path()).c_str());
@@ -544,7 +544,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
// try to load the base backend
for (const auto & search_path : search_paths) {
fs::path filename = backend_filename_prefix().native() + name_path.native() + backend_filename_extension().native();
fs::path path = search_path.native() + filename.native();
fs::path path = search_path / filename;
if (fs::exists(path)) {
return get_reg().load_backend(path, silent);
}

View File

@@ -2790,10 +2790,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
(char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND,
output_ne_offset);
int64_t antiquantGroupSize = 0;
if (src0->ne[0] > QK8_0) {
antiquantGroupSize = QK8_0;
}
ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
nullptr, nullptr, nullptr, QK8_0, acl_output_tensor,
nullptr, nullptr, nullptr, antiquantGroupSize, acl_output_tensor,
&workspaceSize, &executor));
if (workspaceAddr == nullptr) {
workspaceAddr = workspace_allocator.alloc(workspaceSize);
@@ -2833,7 +2837,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
nullptr, nullptr, nullptr, nullptr, QK8_0,
nullptr, nullptr, nullptr, nullptr, antiquantGroupSize,
acl_output_tensor, &workspaceSize, &executor));
ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
workspaceAddr, workspaceSize, executor, ctx.stream()));

View File

@@ -1689,11 +1689,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_MUL_MAT: {
switch (op->src[0]->type) {
case GGML_TYPE_Q8_0:
// Current groupsize should not be greater than k-1 in
// aclnnWeightQuantBatchMatmulV2GetWorkspaceSize
if (op->src[0]->ne[0] <= QK8_0) {
return false;
}
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0:

View File

@@ -287,17 +287,25 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
endif()
endif()
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ")
message(STATUS "PowerPC detected")
execute_process(COMMAND bash -c "grep POWER /proc/cpuinfo | head -n 1" OUTPUT_VARIABLE POWER_M)
if (${POWER_M} MATCHES "POWER10")
list(APPEND ARCH_FLAGS -mcpu=power10)
elseif (${POWER_M} MATCHES "POWER9")
list(APPEND ARCH_FLAGS -mcpu=power9)
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
file(READ "/proc/cpuinfo" POWER10_M)
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc")
execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
endif()
string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
elseif (EXTRACTED_NUMBER EQUAL 9)
list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
else()
list(APPEND ARCH_FLAGS -mcpu=powerpc64 -mtune=native)
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
message(STATUS "loongarch64 detected")

View File

@@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int nb = n / QK_K;
#ifdef __ARM_NEON
#ifdef __ARM_FEATURE_SVE
const int vector_length = ggml_cpu_get_sve_cnt()*8;
float sum = 0;
svuint8_t m4b = svdup_n_u8(0xf);
svint32_t vzero = svdup_n_s32(0);
svuint8_t mone = svdup_n_u8(0x30);
svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
for (int i = 0; i < nb; ++i) {
const float d_all = GGML_FP16_TO_FP32(x[i].d);
const uint8_t * GGML_RESTRICT q6 = x[i].ql;
const uint8_t * GGML_RESTRICT qh = x[i].qh;
const int8_t * GGML_RESTRICT q8 = y[i].qs;
const int8_t * GGML_RESTRICT scale = x[i].scales;
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
const svint64_t prod = svdup_n_s64(0);
int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
svdot_s64(prod, q8sums_2, q6scales_2)));
int32_t isum = 0;
switch (vector_length) {
case 128:
{
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
svint32_t isum_tmp = svdup_n_s32(0);
for (int j = 0; j < QK_K/128; ++j) {
svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
qh += 32;
svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
q6 += 64;
svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
q8 += 64;
q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
scale += 4;
q8bytes_1 = svld1_s8(pg8_16, q8);
q8bytes_2 = svld1_s8(pg8_16, q8+16);
q8bytes_3 = svld1_s8(pg8_16, q8+32);
q8bytes_4 = svld1_s8(pg8_16, q8+48);
q8 += 64;
q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
scale += 4;
}
isum += svaddv_s32(pg32_4, isum_tmp);
sum += d_all * y[i].d * (isum - 32 * isum_mins);
}
break;
case 256:
case 512:
{
const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
svint32_t isum_tmp = svdup_n_s32(0);
for (int j = 0; j < QK_K/128; j++) {
svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
qh += 32;
svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
q6 += 64;
svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
q8 += 128;
q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
scale += 8;
}
isum += svaddv_s32(pg32_8, isum_tmp);
sum += d_all * y[i].d * (isum - 32 * isum_mins);
}
break;
default:
assert(false && "Unsupported vector length");
break;
}
}
*s = sum;
#elif __ARM_NEON
float sum = 0;
const uint8x16_t m4b = vdupq_n_u8(0xF);

View File

@@ -8548,6 +8548,69 @@ static void ggml_compute_forward_group_norm(
}
}
// ggml_compute_forward_l2_norm
static void ggml_compute_forward_l2_norm_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_UNARY_OP_LOCALS
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float)(x[i00] * x[i00]);
}
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
memcpy(y, x, ne00 * sizeof(float));
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
ggml_vec_scale_f32(ne00, y, scale);
}
}
}
}
static void ggml_compute_forward_l2_norm(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_l2_norm_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_mul_mat
static void ggml_compute_forward_mul_mat_one_chunk(
@@ -13604,6 +13667,184 @@ static void ggml_compute_forward_gla(
}
}
// ggml_compute_forward_rwkv_wkv7
static void ggml_compute_forward_rwkv_wkv7_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const int64_t T = dst->src[1]->ne[2];
const int64_t C = dst->ne[0];
const int64_t HEADS = dst->src[1]->ne[1];
const int64_t n_seqs = dst->src[6]->ne[1];
const int64_t head_size = C / HEADS;
float * dst_data = (float *) dst->data;
float * state = ((float *) dst->data) + C * T;
const int ith = params->ith;
const int nth = params->nth;
if (ith >= HEADS) {
return;
}
const int h_start = (HEADS * ith) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
float * r = (float *) dst->src[0]->data;
float * w = (float *) dst->src[1]->data;
float * k = (float *) dst->src[2]->data;
float * v = (float *) dst->src[3]->data;
float * a = (float *) dst->src[4]->data;
float * b = (float *) dst->src[5]->data;
int64_t t_stride = HEADS * head_size; // Same to C
int64_t h_stride = C / HEADS;
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
int64_t h_stride_2d = head_size * head_size;
#if defined(GGML_SIMD)
for (int64_t t = 0; t < T; t++) {
int64_t t_offset = t * t_stride;
int64_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
for (int64_t h = h_start; h < h_end; h++) {
int64_t h_offset = h * h_stride;
int64_t t_h_offset = t_offset + h_offset;
int64_t h_2d_offset = h * h_stride_2d;
for (int64_t ii = 0; ii < head_size; ii++) {
int64_t t_h_i_offset = t_h_offset + ii;
int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
float sa = 0;
{
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
GGML_F32_VEC ax[GGML_F32_ARR];
GGML_F32_VEC ay[GGML_F32_ARR];
for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
}
}
GGML_F32_VEC_REDUCE(sa, sum);
}
GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
int64_t j = 0;
GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
for (; j < head_size; j += GGML_F32_STEP) {
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
// kv + s * decay + sa * b
state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
}
}
GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
// There shouldn't be left-overs though.
for (; j < head_size; j++) {
int64_t t_h_j_offset = t_h_offset + j;
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
float r_val = r[t_h_j_offset];
float w_val = w[t_h_j_offset];
float k_val = k[t_h_j_offset];
float b_val = b[t_h_j_offset];
float kv_val = v[t_h_i_offset] * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
}
}
}
}
#else
for (int64_t t = 0; t < T; t++) {
int64_t t_offset = t * t_stride;
int64_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
for (int64_t h = h_start; h < h_end; h++) {
int64_t h_offset = h * h_stride;
int64_t t_h_offset = t_offset + h_offset;
int64_t h_2d_offset = h * h_stride_2d;
for (int64_t i = 0; i < head_size; i++) {
int64_t t_h_i_offset = t_h_offset + i;
int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
float v_val = v[t_h_i_offset];
float sa = 0, result = 0;
for (int64_t j = 0; j < head_size; j++) {
sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
}
for (int64_t j = 0; j < head_size; j++) {
int64_t t_h_j_offset = t_h_offset + j;
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
float r_val = r[t_h_j_offset];
float w_val = w[t_h_j_offset];
float k_val = k[t_h_j_offset];
float b_val = b[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
result += state_cur[h_2d_i_j_offset] * r_val;
}
dst_data[t_h_i_offset] = result;
}
}
}
#endif
}
static void ggml_compute_forward_rwkv_wkv7(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_rwkv_wkv7_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_map_unary
static void ggml_compute_forward_map_unary_f32(
@@ -14170,6 +14411,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_group_norm(params, tensor);
} break;
case GGML_OP_L2_NORM:
{
ggml_compute_forward_l2_norm(params, tensor);
} break;
case GGML_OP_MUL_MAT:
{
ggml_compute_forward_mul_mat(params, tensor);
@@ -14357,6 +14602,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_gla(params, tensor);
} break;
case GGML_OP_RWKV_WKV7:
{
ggml_compute_forward_rwkv_wkv7(params, tensor);
} break;
case GGML_OP_MAP_UNARY:
{
ggml_unary_op_f32_t fun;
@@ -14582,6 +14831,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_CONCAT:
case GGML_OP_MUL_MAT:
@@ -14648,14 +14898,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
{
n_tasks = n_threads;
} break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_GET_REL_POS:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32:

View File

@@ -395,11 +395,11 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
#if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
c = __builtin_amdgcn_sdot4(a, b, c, false);
#elif defined(RDNA3)
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
#elif defined(__gfx1010__) || defined(__gfx900__)
#elif defined(RDNA1) || defined(__gfx900__)
int tmp1;
int tmp2;
asm("\n \
@@ -678,7 +678,7 @@ struct ggml_tensor_extra_gpu {
};
#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
#define USE_CUDA_GRAPH
#endif

View File

@@ -52,12 +52,11 @@ typedef half (*vec_dot_KQ_f16_t)(
typedef float (*vec_dot_KQ_f32_t)(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
template<typename T, int D>
template<typename T, int D, int warp_size>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_v);
T sum = 0.0f;
@@ -93,12 +92,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
return sum;
}
template<typename T, int D>
template<typename T, int D, int warp_size>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_v);
T sum = 0.0f;
@@ -138,12 +136,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
return sum;
}
template<typename T, int D>
template<typename T, int D, int warp_size>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_v);
T sum = 0.0f;
@@ -186,12 +183,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
return sum;
}
template<typename T, int D>
template<typename T, int D, int warp_size>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_v);
T sum = 0.0f;
@@ -238,12 +234,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
return sum;
}
template <typename T, int D>
template <typename T, int D, int warp_size>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_v);
T sum = 0.0f;
@@ -272,12 +267,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
return sum;
}
template <typename T, int D>
template <typename T, int D, int warp_size>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
const half2 * K_h2 = (const half2 *) K_c;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
@@ -480,25 +474,25 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
return x[i];
}
template <int D>
template <int D, int warp_size = WARP_SIZE>
constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D, warp_size> :
type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D, warp_size> :
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D, warp_size> :
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
nullptr;
}
template <int D>
template <int D, int warp_size = WARP_SIZE>
constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D, warp_size> :
type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D, warp_size> :
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D, warp_size> :
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
nullptr;
}
@@ -681,7 +675,8 @@ static void on_no_fattn_vec_case(const int D) {
template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
void launch_fattn(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
const int warp_size = WARP_SIZE
) {
constexpr int ncols = ncols1 * ncols2;
@@ -704,8 +699,6 @@ void launch_fattn(
GGML_ASSERT(Q->ne[3] == 1);
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int id = ggml_cuda_get_device();
@@ -805,7 +798,6 @@ void launch_fattn(
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
GGML_ASSERT(block_dim.x % warp_size == 0);
GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
(const char *) Q->data,
K_data,

View File

@@ -469,6 +469,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
@@ -485,7 +486,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
@@ -500,7 +501,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
return;
}
constexpr int parallel_blocks = 1;
@@ -514,7 +515,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
}
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View File

@@ -36,7 +36,7 @@
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv6.cuh"
#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml.h"
@@ -2196,6 +2196,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GROUP_NORM:
ggml_cuda_op_group_norm(ctx, dst);
break;
case GGML_OP_L2_NORM:
ggml_cuda_op_l2_norm(ctx, dst);
break;
case GGML_OP_CONCAT:
ggml_cuda_op_concat(ctx, dst);
break;
@@ -2304,6 +2307,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GATED_LINEAR_ATTN:
ggml_cuda_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_RWKV_WKV7:
ggml_cuda_op_rwkv_wkv7(ctx, dst);
break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst);
break;
@@ -2610,13 +2616,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo result_info;
#ifdef __HIP_PLATFORM_AMD__
hipGraphNode_t errorNode;
hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
#else
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
#endif
#else
cudaGraphNode_t errorNode;
cudaGraphExecUpdateResult result_info;
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
#endif // CUDART_VERSION >= 12000
if (stat == cudaErrorGraphExecUpdateFailure) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
@@ -3159,6 +3167,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
break;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return true;
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
@@ -3213,6 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_FLASH_ATTN_EXT: {
#ifndef FLASH_ATTN_AVAILABLE

View File

@@ -47,11 +47,89 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
1;
}
enum mmvq_parameter_table_id {
MMVQ_PARAMETERS_GENERIC = 0,
MMVQ_PARAMETERS_GCN,
MMVQ_PARAMETERS_RDNA2
};
static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
#if defined(RDNA2) || defined(RDNA3)
return MMVQ_PARAMETERS_RDNA2;
#elif defined(GCN) || defined(CDNA)
return MMVQ_PARAMETERS_GCN;
#else
return MMVQ_PARAMETERS_GENERIC;
#endif
}
static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
return MMVQ_PARAMETERS_RDNA2;
}
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
return MMVQ_PARAMETERS_GCN;
}
return MMVQ_PARAMETERS_GENERIC;
}
static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC) {
switch (ncols_y) {
case 1:
case 2:
case 3:
case 4:
return 4;
case 5:
case 6:
case 7:
case 8:
return 2;
default:
return 1;
}
} else if (table_id == MMVQ_PARAMETERS_GCN) {
switch (ncols_y) {
case 1:
case 2:
case 3:
case 4:
return 2;
case 5:
case 6:
case 7:
case 8:
default:
return 1;
}
}
return 1;
}
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
switch (ncols_y) {
case 1:
return 1;
case 2:
case 3:
case 4:
case 5:
case 6:
case 7:
case 8:
return 2;
default:
return 1;
}
}
return 1;
}
template <ggml_type type, int ncols_y>
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
// tell the compiler to use as many registers as it wants, see nwarps definition below
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -59,24 +137,20 @@ static __global__ void mul_mat_vec_q(
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
constexpr int nwarps = calc_nwarps(ncols_y, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
constexpr int nwarps = 1;
constexpr int rows_per_cuda_block = 1;
#else
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
const int tid = warp_size*threadIdx.y + threadIdx.x;
const int row0 = rows_per_cuda_block*blockIdx.x;
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_col_y = nrows_y / QK8_1;
constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
// partial sum for each thread
// partial sum for each thread
float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
const block_q8_1 * y = (const block_q8_1 *) vy;
@@ -96,7 +170,7 @@ static __global__ void mul_mat_vec_q(
}
}
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
if (threadIdx.y > 0) {
#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
@@ -120,7 +194,7 @@ static __global__ void mul_mat_vec_q(
for (int l = 0; l < nwarps-1; ++l) {
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
}
tmp[j][i] = warp_reduce_sum(tmp[j][i]);
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
}
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
@@ -129,6 +203,13 @@ static __global__ void mul_mat_vec_q(
}
}
static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
const dim3 block_nums(nblocks, 1, 1);
const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
return {block_nums, block_dims};
}
template <ggml_type type>
static void mul_mat_vec_q_cuda(
const void * vx, const void * vy, float * dst,
@@ -137,65 +218,67 @@ static void mul_mat_vec_q_cuda(
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
int id = ggml_cuda_get_device();
int64_t nwarps = 1;
int64_t rows_per_cuda_block = 1;
if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
switch(ncols_y) {
case 1:
nwarps = 4;
rows_per_cuda_block = 1;
break;
case 2:
case 3:
case 4:
nwarps = 4;
rows_per_cuda_block = 2;
break;
case 5:
case 6:
case 7:
case 8:
nwarps = 2;
rows_per_cuda_block = 2;
break;
default:
GGML_ABORT("fatal error");
break;
}
}
const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
const dim3 block_nums(nblocks, 1, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
switch (ncols_y) {
case 1:
mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
case 2:
mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 2;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
case 3:
mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 3;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
case 4:
mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 4;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
case 5:
mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 5;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
case 6:
mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 6;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
case 7:
mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 7;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
case 8:
mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 8;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
default:
GGML_ABORT("fatal error");
break;

View File

@@ -201,6 +201,85 @@ static __global__ void rms_norm_back_f32(
}
}
// template <int block_size>
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
// const int tid = threadIdx.x;
// float tmp = 0.0f; // partial sum for thread in warp
// for (int col = tid; col < ncols; col += block_size) {
// const float xi = x[row*ncols + col];
// tmp += xi * xi;
// }
// // sum up partial sums
// tmp = warp_reduce_sum(tmp);
// if (block_size > WARP_SIZE) {
// __shared__ float s_sum[32];
// int warp_id = threadIdx.x / WARP_SIZE;
// int lane_id = threadIdx.x % WARP_SIZE;
// if (lane_id == 0) {
// s_sum[warp_id] = tmp;
// }
// __syncthreads();
// tmp = s_sum[lane_id];
// tmp = warp_reduce_sum(tmp);
// }
// // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
// const float scale = rsqrtf(fmaxf(tmp, eps * eps));
// for (int col = tid; col < ncols; col += block_size) {
// dst[row*ncols + col] = scale * x[row*ncols + col];
// }
// }
template <int block_size>
static __global__ void l2_norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
tmp += xi * xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
}
}
static void norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
@@ -248,6 +327,19 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
}
}
static void l2_norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
@@ -340,3 +432,27 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
}
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_UNARY_OP_LOCALS;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}

View File

@@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -112,7 +112,7 @@
#define cudaGraphExecDestroy hipGraphExecDestroy
#define cudaGraphLaunch hipGraphLaunch
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
#define cudaGraphNodeType hipGraphNodeType
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
#define cudaGraphInstantiate hipGraphInstantiate

View File

@@ -119,7 +119,7 @@
#define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
@@ -132,6 +132,7 @@
#define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamBeginCapture musaStreamBeginCapture
#define cudaStreamEndCapture musaStreamEndCapture
typedef mt_bfloat16 nv_bfloat16;

199
ggml/src/ggml-cuda/wkv.cu Normal file
View File

@@ -0,0 +1,199 @@
#include "common.cuh"
#include "wkv.cuh"
template <int block_size>
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int head_size = block_size;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float state[head_size];
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
__syncthreads();
_tf[tid] = tf[head_i * head_size + tid];
__syncthreads();
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
__syncthreads();
const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4& k = (float4&)(_k[j]);
const float4& r = (float4&)(_r[j]);
const float4& tf = (float4&)(_tf[j]);
const float4& td = (float4&)(_td[j]);
float4& s = (float4&)(state[j]);
float4 kv;
kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;
y += r.x * (tf.x * kv.x + s.x);
y += r.y * (tf.y * kv.y + s.y);
y += r.z * (tf.z * kv.z + s.z);
y += r.w * (tf.w * kv.w + s.w);
s.x = s.x * td.x + kv.x;
s.y = s.y * td.y + kv.y;
s.z = s.z * td.z + kv.z;
s.w = s.w * td.w + kv.w;
}
dst[t] = y;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}
template <int block_size>
static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int head_size = block_size;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float state[head_size];
__shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];
#ifndef GGML_USE_MUSA
#pragma unroll
#endif
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
}
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
__syncthreads();
float sa = 0;
#pragma unroll
for (int j = 0; j < head_size; j += 4)
{
const float4& a = (float4&)(_a[j]);
const float4& s = (float4&)(state[j]);
sa += a.x * s.x;
sa += a.y * s.y;
sa += a.z * s.z;
sa += a.w * s.w;
}
const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4& r = (float4&)(_r[j]);
const float4& w = (float4&)(_w[j]);
const float4& k = (float4&)(_k[j]);
const float4& b = (float4&)(_b[j]);
float4& s = (float4&)(state[j]);
float4 kv;
kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;
s.x = s.x * w.x + kv.x + sa * b.x;
s.y = s.y * w.y + kv.y + sa * b.y;
s.z = s.z * w.z + kv.z + sa * b.z;
s.w = s.w * w.w + kv.w + sa * b.w;
y += s.x * r.x;
y += s.y * r.y;
y += s.z * r.z;
y += s.w * r.w;
}
dst[t] = y;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
}
}
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * k_d = (const float *)dst->src[0]->data;
const float * v_d = (const float *)dst->src[1]->data;
const float * r_d = (const float *)dst->src[2]->data;
const float * tf_d = (const float *)dst->src[3]->data;
const float * td_d = (const float *)dst->src[4]->data;
const float * s_d = (const float *)dst->src[5]->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
if (C / H == CUDA_WKV_BLOCK_SIZE) {
rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
} else {
rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
}
}
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * r_d = (const float *)dst->src[0]->data;
const float * w_d = (const float *)dst->src[1]->data;
const float * k_d = (const float *)dst->src[2]->data;
const float * v_d = (const float *)dst->src[3]->data;
const float * a_d = (const float *)dst->src[4]->data;
const float * b_d = (const float *)dst->src[5]->data;
const float * s_d = (const float *)dst->src[6]->data;
const int64_t B = dst->src[6]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
if (C / H == CUDA_WKV_BLOCK_SIZE) {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
} else {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
}
}

View File

@@ -3,3 +3,5 @@
#define CUDA_WKV_BLOCK_SIZE 64
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -1,89 +0,0 @@
#include "common.cuh"
#include "wkv6.cuh"
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int head_size = CUDA_WKV_BLOCK_SIZE;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float state[head_size];
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
__syncthreads();
_tf[tid] = tf[head_i * head_size + tid];
__syncthreads();
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
__syncthreads();
const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4& k = (float4&)(_k[j]);
const float4& r = (float4&)(_r[j]);
const float4& tf = (float4&)(_tf[j]);
const float4& td = (float4&)(_td[j]);
float4& s = (float4&)(state[j]);
float4 kv;
kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;
y += r.x * (tf.x * kv.x + s.x);
y += r.y * (tf.y * kv.y + s.y);
y += r.z * (tf.z * kv.z + s.z);
y += r.w * (tf.w * kv.w + s.w);
s.x = s.x * td.x + kv.x;
s.y = s.y * td.y + kv.y;
s.z = s.z * td.z + kv.z;
s.w = s.w * td.w + kv.w;
}
dst[t] = y;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * k_d = (const float *)dst->src[0]->data;
const float * v_d = (const float *)dst->src[1]->data;
const float * r_d = (const float *)dst->src[2]->data;
const float * tf_d = (const float *)dst->src[3]->data;
const float * td_d = (const float *)dst->src[4]->data;
const float * s_d = (const float *)dst->src[5]->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
}

View File

@@ -285,6 +285,13 @@ typedef struct {
float eps;
} ggml_metal_kargs_rms_norm;
typedef struct {
int32_t ne00;
int32_t ne00_4;
uint64_t nb01;
float eps;
} ggml_metal_kargs_l2_norm;
typedef struct {
int64_t ne00;
int64_t ne01;

View File

@@ -46,6 +46,7 @@ static struct ggml_backend_device g_ggml_backend_metal_device;
static struct ggml_backend_metal_device_context {
id<MTLDevice> mtl_device;
int mtl_device_ref_count;
id<MTLLibrary> mtl_library;
bool has_simdgroup_reduction;
bool has_simdgroup_mm;
@@ -57,6 +58,7 @@ static struct ggml_backend_metal_device_context {
} g_ggml_ctx_dev_main = {
/*.mtl_device =*/ nil,
/*.mtl_device_ref_count =*/ 0,
/*.mtl_library =*/ nil,
/*.has_simdgroup_reduction =*/ false,
/*.has_simdgroup_mm =*/ false,
/*.has_residency_sets =*/ false,
@@ -108,6 +110,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
ctx->mtl_device_ref_count--;
if (ctx->mtl_device_ref_count == 0) {
if (ctx->mtl_library) {
[ctx->mtl_library release];
ctx->mtl_library = nil;
}
if (ctx->mtl_device) {
[ctx->mtl_device release];
ctx->mtl_device = nil;
@@ -177,10 +184,13 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
GGML_METAL_KERNEL_TYPE_RMS_NORM,
GGML_METAL_KERNEL_TYPE_L2_NORM,
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
GGML_METAL_KERNEL_TYPE_NORM,
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@@ -495,6 +505,139 @@ static void * ggml_metal_host_malloc(size_t n) {
return data;
}
// load library
//
// - first check if the library is embedded
// - then check if the library is in the bundle
// - if not found, load the source and compile it
// - if that fails, return NULL
static id<MTLLibrary> ggml_metal_load_library(id<MTLDevice> device, bool use_bfloat) {
id<MTLLibrary> metal_library = nil;
NSError * error = nil;
NSString * src = nil;
#if GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
extern const char ggml_metallib_start[];
extern const char ggml_metallib_end[];
src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
#else
#ifdef SWIFT_PACKAGE
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
#else
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
#endif
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
if (path_lib == nil) {
// Try to find the resource in the directory where the current binary located.
NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
// Optionally, if this is a symlink, try to resolve it.
default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
// It is a relative path, adding the binary directory as directory prefix.
default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
}
if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
// Link to the resource could not be resolved.
default_metallib_path = nil;
} else {
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
}
}
} else {
// The resource couldn't be found in the binary's directory.
default_metallib_path = nil;
}
path_lib = default_metallib_path;
}
if (path_lib != nil) {
// pre-compiled library found
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
metal_library = [device newLibraryWithURL:libURL error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
}
} else {
GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
NSString * path_source;
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
if (path_resource) {
path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
} else {
path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
}
if (path_source == nil) {
GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
path_source = @"ggml-metal.metal";
}
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
}
}
#endif
if (!metal_library) {
@autoreleasepool {
// dictionary of preprocessor macros
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
if (use_bfloat) {
[prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
}
#if GGML_METAL_EMBED_LIBRARY
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
#endif
MTLCompileOptions * options = [MTLCompileOptions new];
options.preprocessorMacros = prep;
//[options setFastMathEnabled:false];
metal_library = [device newLibraryWithSource:src options:options error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
}
#if !__has_feature(objc_arc)
[options release];
#endif
}
}
#if GGML_METAL_EMBED_LIBRARY
[src release];
#endif // GGML_METAL_EMBED_LIBRARY
return metal_library;
}
static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
GGML_LOG_INFO("%s: allocating\n", __func__);
@@ -522,136 +665,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
id<MTLLibrary> metal_library = nil;
// load library
//
// - first check if the library is embedded
// - then check if the library is in the bundle
// - if not found, load the source and compile it
// - if that fails, return NULL
{
NSError * error = nil;
NSString * src = nil;
#if GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
extern const char ggml_metallib_start[];
extern const char ggml_metallib_end[];
src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
#else
#ifdef SWIFT_PACKAGE
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
#else
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
#endif
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
if (path_lib == nil) {
// Try to find the resource in the directory where the current binary located.
NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
// Optionally, if this is a symlink, try to resolve it.
default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
// It is a relative path, adding the binary directory as directory prefix.
default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
}
if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
// Link to the resource could not be resolved.
default_metallib_path = nil;
} else {
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
}
}
} else {
// The resource couldn't be found in the binary's directory.
default_metallib_path = nil;
}
path_lib = default_metallib_path;
}
if (path_lib != nil) {
// pre-compiled library found
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
metal_library = [device newLibraryWithURL:libURL error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
}
} else {
GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
NSString * path_source;
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
if (path_resource) {
path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
} else {
path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
}
if (path_source == nil) {
GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
path_source = @"ggml-metal.metal";
}
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
}
}
#endif
if (!metal_library) {
@autoreleasepool {
// dictionary of preprocessor macros
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
if (ctx_dev->use_bfloat) {
[prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
}
#if GGML_METAL_EMBED_LIBRARY
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
#endif
MTLCompileOptions * options = [MTLCompileOptions new];
options.preprocessorMacros = prep;
//[options setFastMathEnabled:false];
metal_library = [device newLibraryWithSource:src options:options error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
}
#if !__has_feature(objc_arc)
[options release];
#endif
}
}
#if GGML_METAL_EMBED_LIBRARY
[src release];
#endif // GGML_METAL_EMBED_LIBRARY
if (ctx_dev->mtl_library == nil) {
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
}
id<MTLLibrary> metal_library = ctx_dev->mtl_library;
if (metal_library == nil) {
GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
return NULL;
}
// print MTL GPU family:
@@ -725,7 +746,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
[metal_function release]; \
if (error) { \
GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
[metal_library release]; \
return NULL; \
} \
} else { \
@@ -793,10 +813,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
@@ -1044,8 +1067,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
}
[metal_library release];
return ctx;
}
@@ -1236,6 +1257,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_ARGMAX:
return true;
@@ -1273,6 +1295,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
@@ -2201,6 +2225,83 @@ static void ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_RWKV_WKV6:
{
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == 64);
size_t offs_src3 = 0;
size_t offs_src4 = 0;
size_t offs_src5 = 0;
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
[encoder setBytes:&B length:sizeof(B) atIndex:7];
[encoder setBytes:&T length:sizeof(T) atIndex:8];
[encoder setBytes:&C length:sizeof(C) atIndex:9];
[encoder setBytes:&H length:sizeof(H) atIndex:10];
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
} break;
case GGML_OP_RWKV_WKV7:
{
const int64_t B = dst->src[6]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == 64);
size_t offs_src3 = 0;
size_t offs_src4 = 0;
size_t offs_src5 = 0;
size_t offs_src6 = 0;
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
id<MTLBuffer> id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
[encoder setBytes:&B length:sizeof(B) atIndex:8];
[encoder setBytes:&T length:sizeof(T) atIndex:9];
[encoder setBytes:&C length:sizeof(C) atIndex:10];
[encoder setBytes:&H length:sizeof(H) atIndex:11];
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
} break;
case GGML_OP_MUL_MAT:
{
GGML_ASSERT(ne00 == ne10);
@@ -3107,6 +3208,42 @@ static void ggml_metal_encode_node(
const int64_t nrows = ggml_nrows(src0);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_L2_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(src0));
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
int nth = 32; // SIMD width
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
nth *= 2;
}
nth = MIN(nth, ne00/4);
ggml_metal_kargs_l2_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.eps =*/ eps,
};
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_GROUP_NORM:

View File

@@ -1295,6 +1295,184 @@ kernel void kernel_ssm_scan_f32(
}
}
kernel void kernel_rwkv_wkv6_f32(
device const float * k,
device const float * v,
device const float * r,
device const float * tf,
device const float * td,
device const float * state_in,
device float * dst,
constant uint & B,
constant uint & T,
constant uint & C,
constant uint & H,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint head_size = 64; // TODO: support head_size = 128
const uint batch_id = tgpig.x / H;
const uint head_id = tgpig.x % H;
const uint tid = tpitg.x;
if (batch_id >= B || head_id >= H) {
return;
}
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
threadgroup float _k[head_size];
threadgroup float _r[head_size];
threadgroup float _tf[head_size];
threadgroup float _td[head_size];
float state[head_size];
for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
_tf[tid] = tf[head_id * head_size + tid];
threadgroup_barrier(mem_flags::mem_threadgroup);
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
threadgroup_barrier(mem_flags::mem_threadgroup);
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
threadgroup_barrier(mem_flags::mem_threadgroup);
const float v_val = v[t];
float y = 0.0;
for (uint j = 0; j < head_size; j += 4) {
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
float4 kv = k_vec * v_val;
float4 temp = tf_vec * kv + s_vec;
y += dot(r_vec, temp);
s_vec = s_vec * td_vec + kv;
state[j] = s_vec[0];
state[j+1] = s_vec[1];
state[j+2] = s_vec[2];
state[j+3] = s_vec[3];
}
dst[t] = y;
}
for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid] = state[i];
}
}
kernel void kernel_rwkv_wkv7_f32(
device const float * r,
device const float * w,
device const float * k,
device const float * v,
device const float * a,
device const float * b,
device const float * state_in,
device float * dst,
constant uint & B,
constant uint & T,
constant uint & C,
constant uint & H,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint head_size = 64; // TODO: support head_size = 128
const uint batch_id = tgpig.x / H;
const uint head_id = tgpig.x % H;
const uint tid = tpitg.x;
if (batch_id >= B || head_id >= H) {
return;
}
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
threadgroup float _r[head_size];
threadgroup float _w[head_size];
threadgroup float _k[head_size];
threadgroup float _a[head_size];
threadgroup float _b[head_size];
float state[head_size];
for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i];
}
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
threadgroup_barrier(mem_flags::mem_threadgroup);
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
threadgroup_barrier(mem_flags::mem_threadgroup);
const float v_val = v[t];
float y = 0.0, sa = 0.0;
float4 sa_vec(0.0);
for (int j = 0; j < head_size; j += 4) {
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
sa_vec += a_vec * s_vec;
}
sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
for (uint j = 0; j < head_size; j += 4) {
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
float4 kv = k_vec * v_val;
s_vec = s_vec * w_vec + kv + sa * b_vec;
y += dot(s_vec, r_vec);
state[j] = s_vec[0];
state[j+1] = s_vec[1];
state[j+2] = s_vec[2];
state[j+3] = s_vec[3];
}
dst[t] = y;
}
for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i] = state[i];
}
}
kernel void kernel_argmax(
device const void * x,
device int32_t * dst,
@@ -1463,6 +1641,49 @@ kernel void kernel_rms_norm(
}
}
kernel void kernel_l2_norm(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float scale = 1.0f/sqrt(max(sumf, args.eps));
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
y[i00] = x[i00] * scale;
}
}
kernel void kernel_group_norm(
device const float * src0,
device float * dst,

View File

@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
add_compile_definitions(GGML_USE_MUSA)
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
if (GGML_CUDA_GRAPHS)
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
endif()
if (GGML_CUDA_FORCE_MMQ)
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
endif()

View File

@@ -66,6 +66,9 @@ if (WIN32)
find_package(MKL REQUIRED)
target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
else()
if (GGML_SYCL_GRAPH)
add_compile_definitions(GGML_SYCL_GRAPH)
endif()
if (GGML_SYCL_TARGET STREQUAL "INTEL")
target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")

View File

@@ -26,7 +26,7 @@
#include "softmax.hpp"
#include "tsembd.hpp"
#include "im2col.hpp"
#include "wkv6.hpp"
#include "wkv.hpp"
#include "outprod.hpp"
#include "element_wise.hpp"
#include "cpy.hpp"

View File

@@ -301,6 +301,7 @@ inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) {
return opt;
}
namespace sycl_ex = sycl::ext::oneapi::experimental;
struct ggml_backend_sycl_context {
int device;
std::string name;
@@ -392,6 +393,10 @@ struct ggml_backend_sycl_context {
return pool(device);
}
#ifdef GGML_SYCL_GRAPH
std::unique_ptr<sycl_ex::command_graph<sycl_ex::graph_state::executable>> exec_graph = nullptr;
#endif
ggml_sycl_pool & host_pool(int device) {
if (host_pools[device] == nullptr) {
host_pools[device] = new_pool_for_host(stream(device, 0), device);
@@ -474,6 +479,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13,
const sycl::nd_item<3> &item_ct1) {
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -495,9 +501,9 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i_src0;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
@@ -515,6 +521,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13,
const sycl::nd_item<3> &item_ct1) {
@@ -534,9 +541,9 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i_src0;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
@@ -566,9 +573,11 @@ struct bin_bcast_sycl {
int nr[4] = { nr0, nr1, nr2, nr3 };
// collapse dimensions until first broadcast dimension
int64_t cne0[] = {ne0, ne1, ne2, ne3};
int64_t cne[] = {ne0, ne1, ne2, ne3};
int64_t cne0[] = {ne00, ne01, ne02, ne03};
int64_t cne1[] = {ne10, ne11, ne12, ne13};
size_t cnb0[] = {nb0, nb1, nb2, nb3};
size_t cnb[] = {nb0, nb1, nb2, nb3};
size_t cnb0[] = {nb00, nb01, nb02, nb03};
size_t cnb1[] = {nb10, nb11, nb12, nb13};
auto collapse = [](int64_t cne[]) {
cne[0] *= cne[1];
@@ -583,32 +592,41 @@ struct bin_bcast_sycl {
cnb[3] *= cne[3];
};
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
}
if (i > 0) {
collapse_nb(cnb0, cne0);
collapse_nb(cnb1, cne1);
collapse(cne0);
collapse(cne1);
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
}
if (i > 0) {
collapse_nb(cnb, cne);
collapse_nb(cnb0, cne0);
collapse_nb(cnb1, cne1);
collapse(cne);
collapse(cne0);
collapse(cne1);
}
}
}
{
int64_t ne0 = cne0[0];
int64_t ne1 = cne0[1];
int64_t ne2 = cne0[2];
int64_t ne3 = cne0[3];
int64_t ne0 = cne[0];
int64_t ne1 = cne[1];
int64_t ne2 = cne[2];
int64_t ne3 = cne[3];
int64_t ne10 = cne1[0];
int64_t ne11 = cne1[1];
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];
size_t nb0 = cnb0[0];
size_t nb1 = cnb0[1];
size_t nb2 = cnb0[2];
size_t nb3 = cnb0[3];
size_t nb0 = cnb[0];
size_t nb1 = cnb[1];
size_t nb2 = cnb[2];
size_t nb3 = cnb[3];
size_t nb00 = cnb0[0];
size_t nb01 = cnb0[1];
size_t nb02 = cnb0[2];
size_t nb03 = cnb0[3];
size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1];
@@ -625,6 +643,28 @@ struct bin_bcast_sycl {
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
size_t s00 = nb00 / sizeof(src0_t);
size_t s01 = nb01 / sizeof(src0_t);
size_t s02 = nb02 / sizeof(src0_t);
size_t s03 = nb03 / sizeof(src0_t);
GGML_UNUSED(s00);
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s10 == 1);
@@ -661,8 +701,8 @@ struct bin_bcast_sycl {
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast_unravel<bin_op>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
s13, item_ct1);
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
s03, s11, s12, s13, item_ct1);
});
}
} else {
@@ -680,7 +720,7 @@ struct bin_bcast_sycl {
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
ne2, ne3, ne10, ne11, ne12, ne13,
s1, s2, s3, s11, s12, s13,
s1, s2, s3, s01, s02, s03, s11, s12, s13,
item_ct1);
});
}

View File

@@ -138,7 +138,7 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
sycl::range<3>(1, 1, WARP_SIZE),
sycl::range<3>(1, 1, WARP_SIZE)),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
});

View File

@@ -210,7 +210,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
nrows, item_ct1);
});
@@ -879,7 +879,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
vx, y, dst, ncols, nrows, item_ct1);
});
@@ -902,7 +902,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
vx, y, dst, ncols, nrows, item_ct1);
});
@@ -923,7 +923,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
vx, y, dst, ncols, nrows, item_ct1);
});
@@ -944,7 +944,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
vx, y, dst, ncols, nrows, item_ct1);
});
@@ -965,7 +965,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
vx, y, dst, ncols, nrows, item_ct1);
});
@@ -986,7 +986,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
vx, y, dst, ncols, nrows, item_ct1);
});
@@ -1004,7 +1004,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
});
}
@@ -1020,7 +1020,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
});
}
@@ -1036,7 +1036,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
});
}
@@ -1049,7 +1049,7 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
});
}
@@ -1065,7 +1065,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
});
}
@@ -1143,7 +1143,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
default:
printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
GGML_ABORT("fatal error");
break;
}
GGML_UNUSED(src1);

View File

@@ -1,7 +1,7 @@
#include "common.hpp"
#include "element_wise.hpp"
void acc_f32(const float * x, const float * y, float * dst, const int ne,
static void acc_f32(const float * x, const float * y, float * dst, const int ne,
const int ne10, const int ne11, const int ne12,
const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -20,7 +20,7 @@ void acc_f32(const float * x, const float * y, float * dst, const int ne,
}
}
void gelu_f32(const float * x, float * dst, const int k,
static void gelu_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -37,7 +37,7 @@ void gelu_f32(const float * x, float * dst, const int k,
sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
}
void silu_f32(const float * x, float * dst, const int k,
static void silu_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -48,7 +48,7 @@ void silu_f32(const float * x, float * dst, const int k,
dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
}
void gelu_quick_f32(const float *x, float *dst, int k,
static void gelu_quick_f32(const float *x, float *dst, int k,
const sycl::nd_item<3> &item_ct1) {
const float GELU_QUICK_COEF = -1.702f;
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -59,7 +59,7 @@ void gelu_quick_f32(const float *x, float *dst, int k,
dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
}
void tanh_f32(const float *x, float *dst, int k,
static void tanh_f32(const float *x, float *dst, int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -69,7 +69,7 @@ void tanh_f32(const float *x, float *dst, int k,
dst[i] = sycl::tanh((float)(x[i]));
}
void relu_f32(const float * x, float * dst, const int k,
static void relu_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -80,7 +80,7 @@ void relu_f32(const float * x, float * dst, const int k,
dst[i] = sycl::fmax((float)(x[i]), (float)0);
}
void sigmoid_f32(const float * x, float * dst, const int k,
static void sigmoid_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -91,7 +91,7 @@ void sigmoid_f32(const float * x, float * dst, const int k,
dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i]));
}
void sqrt_f32(const float * x, float * dst, const int k,
static void sqrt_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -102,7 +102,7 @@ void sqrt_f32(const float * x, float * dst, const int k,
dst[i] = sycl::sqrt(x[i]);
}
void sin_f32(const float * x, float * dst, const int k,
static void sin_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -113,7 +113,7 @@ void sin_f32(const float * x, float * dst, const int k,
dst[i] = sycl::sin(x[i]);
}
void cos_f32(const float * x, float * dst, const int k,
static void cos_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -124,7 +124,7 @@ void cos_f32(const float * x, float * dst, const int k,
dst[i] = sycl::cos(x[i]);
}
void hardsigmoid_f32(const float * x, float * dst, const int k,
static void hardsigmoid_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -135,7 +135,7 @@ void hardsigmoid_f32(const float * x, float * dst, const int k,
dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
}
void hardswish_f32(const float * x, float * dst, const int k,
static void hardswish_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -146,7 +146,7 @@ void hardswish_f32(const float * x, float * dst, const int k,
dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
}
void exp_f32(const float * x, float * dst, const int k,
static void exp_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -157,7 +157,7 @@ void exp_f32(const float * x, float * dst, const int k,
dst[i] = sycl::exp(x[i]);
}
void log_f32(const float * x, float * dst, const int k,
static void log_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -173,7 +173,7 @@ void log_f32(const float * x, float * dst, const int k,
}
}
void neg_f32(const float * x, float * dst, const int k,
static void neg_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -184,7 +184,7 @@ void neg_f32(const float * x, float * dst, const int k,
dst[i] = -x[i];
}
void step_f32(const float * x, float * dst, const int k,
static void step_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -195,7 +195,7 @@ void step_f32(const float * x, float * dst, const int k,
dst[i] = x[i] > 0.0f;
}
void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -206,7 +206,7 @@ void leaky_relu_f32(const float *x, float *dst, const int k, const float negativ
sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
}
void sqr_f32(const float * x, float * dst, const int k,
static void sqr_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -217,7 +217,7 @@ void sqr_f32(const float * x, float * dst, const int k,
dst[i] = x[i] * x[i];
}
void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
const int ne12, const int ne13, const float sf0, const float sf1,
const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
@@ -240,7 +240,7 @@ void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
dst[index] = *(const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
}
void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
const sycl::nd_item<3> &item_ct1) {
int nidx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
@@ -262,7 +262,7 @@ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const i
void acc_f32_sycl(const float *x, const float *y, float *dst,
static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int n_elements, const int ne10, const int ne11,
const int ne12, const int nb1, const int nb2,
const int offset, queue_ptr stream) {
@@ -277,7 +277,7 @@ void acc_f32_sycl(const float *x, const float *y, float *dst,
});
}
void gelu_f32_sycl(const float *x, float *dst, const int k,
static void gelu_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
stream->parallel_for(
@@ -289,7 +289,7 @@ void gelu_f32_sycl(const float *x, float *dst, const int k,
});
}
void silu_f32_sycl(const float *x, float *dst, const int k,
static void silu_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
stream->parallel_for(
@@ -301,7 +301,7 @@ void silu_f32_sycl(const float *x, float *dst, const int k,
});
}
void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
stream->parallel_for(
@@ -313,7 +313,7 @@ void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
});
}
void tanh_f32_sycl(const float *x, float *dst, const int k,
static void tanh_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
stream->parallel_for(
@@ -325,7 +325,7 @@ void tanh_f32_sycl(const float *x, float *dst, const int k,
});
}
void relu_f32_sycl(const float *x, float *dst, const int k,
static void relu_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
stream->parallel_for(
@@ -337,7 +337,7 @@ void relu_f32_sycl(const float *x, float *dst, const int k,
});
}
void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
stream->parallel_for(
@@ -349,7 +349,7 @@ void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
});
}
void hardswish_f32_sycl(const float *x, float *dst, const int k,
static void hardswish_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
stream->parallel_for(
@@ -361,7 +361,7 @@ void hardswish_f32_sycl(const float *x, float *dst, const int k,
});
}
void exp_f32_sycl(const float *x, float *dst, const int k,
static void exp_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
stream->parallel_for(
@@ -373,7 +373,7 @@ void exp_f32_sycl(const float *x, float *dst, const int k,
});
}
void log_f32_sycl(const float *x, float *dst, const int k,
static void log_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
stream->parallel_for(
@@ -385,7 +385,7 @@ void log_f32_sycl(const float *x, float *dst, const int k,
});
}
void neg_f32_sycl(const float *x, float *dst, const int k,
static void neg_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
stream->parallel_for(
@@ -397,7 +397,7 @@ void neg_f32_sycl(const float *x, float *dst, const int k,
});
}
void step_f32_sycl(const float *x, float *dst, const int k,
static void step_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
stream->parallel_for(
@@ -409,7 +409,7 @@ void step_f32_sycl(const float *x, float *dst, const int k,
});
}
void sigmoid_f32_sycl(const float *x, float *dst, const int k,
static void sigmoid_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
stream->parallel_for(
@@ -421,7 +421,7 @@ void sigmoid_f32_sycl(const float *x, float *dst, const int k,
});
}
void sqrt_f32_sycl(const float *x, float *dst, const int k,
static void sqrt_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
stream->parallel_for(
@@ -433,7 +433,7 @@ void sqrt_f32_sycl(const float *x, float *dst, const int k,
});
}
void sin_f32_sycl(const float *x, float *dst, const int k,
static void sin_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
stream->parallel_for(
@@ -445,7 +445,7 @@ void sin_f32_sycl(const float *x, float *dst, const int k,
});
}
void cos_f32_sycl(const float *x, float *dst, const int k,
static void cos_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
stream->parallel_for(
@@ -457,7 +457,7 @@ void cos_f32_sycl(const float *x, float *dst, const int k,
});
}
void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
const float negative_slope,
queue_ptr stream) {
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
@@ -470,7 +470,7 @@ void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
});
}
void sqr_f32_sycl(const float *x, float *dst, const int k,
static void sqr_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
stream->parallel_for(
@@ -482,7 +482,7 @@ void sqr_f32_sycl(const float *x, float *dst, const int k,
});
}
void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
const int ne12, const int ne13, const float sf0, const float sf1,
const float sf2, const float sf3, queue_ptr stream) {
@@ -496,7 +496,7 @@ void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01
});
}
void pad_f32_sycl(const float *x, float *dst, const int ne00,
static void pad_f32_sycl(const float *x, float *dst, const int ne00,
const int ne01, const int ne02, const int ne0,
const int ne1, const int ne2, queue_ptr stream) {
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;

View File

@@ -207,7 +207,7 @@ static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_te
const size_t nrows = ne01;
const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
k_get_rows_reorder<qk, qr, dq_reorder>(
src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2,
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
@@ -302,7 +302,6 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *s
// TODO: k-quants
GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
GGML_ABORT("fatal error");
break;
}
}

View File

@@ -46,6 +46,7 @@
static bool g_sycl_loaded = false;
int g_ggml_sycl_debug = 0;
int g_ggml_sycl_disable_optimize = 0;
int g_ggml_sycl_disable_graph = 0;
static ggml_sycl_device_info ggml_sycl_init() {
ggml_sycl_device_info info = {};
@@ -95,7 +96,7 @@ const ggml_sycl_device_info & ggml_sycl_info() {
return info;
}
void print_device_detail(int id, sycl::device &device, std::string device_type) {
static void print_device_detail(int id, sycl::device &device, std::string device_type) {
dpct::device_info prop;
SYCL_CHECK(CHECK_TRY_ERROR(
@@ -118,7 +119,7 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
}
void print_device_opt_feature(int device_count) {
static void print_device_opt_feature(int device_count) {
GGML_LOG_INFO("SYCL Optimization Feature:\n");
GGML_LOG_INFO(
"|ID| Device Type|Reorder|\n");
@@ -191,10 +192,12 @@ static void ggml_check_sycl() try {
if (!initialized) {
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
GGML_LOG_INFO("Running with Environment Variables:\n");
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
GGML_LOG_INFO("Build with Macros:\n");
#if defined(GGML_SYCL_FORCE_MMQ)
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
@@ -333,10 +336,11 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
assert(tensor->view_src->buffer->buft == buffer->buft);
return GGML_STATUS_SUCCESS;
}
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
tensor->extra = extra;
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
if (tensor->type == GGML_TYPE_Q4_0) {
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
tensor->extra = extra;
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
}
if (ggml_is_quantized(tensor->type)) {
// initialize padding to 0 to avoid possible NaN values
@@ -400,7 +404,7 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
const void *ptr_src, size_t size) {
char *host_buf = (char *)malloc(size);
q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
@@ -486,6 +490,22 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
if (buffer == nullptr) {
return;
}
ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
if (ctx != nullptr) {
for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) {
release_extra_gpu(extra);
}
ctx->tensor_extras.clear(); // reset the tensor_extras vector
}
}
static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
/* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer,
/* .get_base = */ ggml_backend_sycl_buffer_get_base,
@@ -495,7 +515,7 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
/* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,
/* .clear = */ ggml_backend_sycl_buffer_clear,
/* .reset = */ NULL,
/* .reset = */ ggml_backend_sycl_buffer_reset,
};
// sycl buffer type
@@ -576,7 +596,6 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
auto dev_count = ggml_backend_sycl_get_device_count();
@@ -604,7 +623,7 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
return &ggml_backend_sycl_buffer_types[device];
}
ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
static ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
int device = ctx->device;
@@ -1666,7 +1685,7 @@ static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
stream->parallel_for(
sycl::nd_range<3>(num_blocks * block_size, block_size),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
});
}
@@ -1687,7 +1706,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
nchannels_y, item_ct1);
});
@@ -1707,7 +1726,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
row_stride_x, channel_stride_x,
nchannels_y / nchannels_x, item_ct1);
@@ -1748,7 +1767,7 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
const sycl::range<3> block_nums(1, nrows, 1);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
k_sum_rows_f32(x, dst, ncols, item_ct1);
});
}
@@ -2680,6 +2699,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
@@ -2898,7 +2923,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
return false;
}
bool ggml_sycl_supports_dmmv(enum ggml_type type) {
static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
@@ -3113,8 +3138,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
const int64_t i2 = i12;
src0_row.data = src0_original + i02*nb02;
src1_row.data = src1_original + + i11*nb11 + i12*nb12;
dst_row.data = dst_original + i1*nb1 + i2*nb2;
src1_row.data = src1_original + i11*nb11 + i12*nb12;
dst_row.data = dst_original + i1*nb1 + i2*nb2;
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
}
@@ -3271,7 +3296,7 @@ static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
}
void ggml_sycl_set_main_device(const int main_device) try {
static void ggml_sycl_set_main_device(const int main_device) try {
if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
return;
}
@@ -3292,7 +3317,7 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
if (!g_sycl_loaded) return false;
if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
@@ -3394,6 +3419,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
case GGML_OP_RMS_NORM:
ggml_sycl_rms_norm(ctx, dst);
break;
case GGML_OP_L2_NORM:
ggml_sycl_l2_norm(ctx, dst);
break;
case GGML_OP_MUL_MAT:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
return false;
@@ -3471,6 +3499,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
case GGML_OP_RWKV_WKV6:
ggml_sycl_op_rwkv_wkv6(ctx, dst);
break;
case GGML_OP_RWKV_WKV7:
ggml_sycl_op_rwkv_wkv7(ctx, dst);
break;
case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst);
break;
@@ -3610,7 +3641,7 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
void reorder_qw(char *data_device, const int ncols, const int nrows,
static void reorder_qw(char *data_device, const int ncols, const int nrows,
size_t size, size_t offset, dpct::queue_ptr stream) {
auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
SYCL_CHECK(
@@ -3624,7 +3655,7 @@ void reorder_qw(char *data_device, const int ncols, const int nrows,
stream->parallel_for(
size / sizeof(block_q4_0),
[=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
[=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
const block_q4_0* x = (const block_q4_0*)tmp_buf;
const int ib = i;
@@ -3638,7 +3669,7 @@ void reorder_qw(char *data_device, const int ncols, const int nrows,
sycl::free(tmp_buf, *stream);
}
void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
static void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
char*data_device = (char*)src0->data;
size_t ncols = src0->ne[0];
size_t nrows = src0->ne[1];
@@ -3647,7 +3678,7 @@ void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
reorder_qw(data_device, ncols, nrows, size, 0, stream);
}
void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
static void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
ggml_tensor *src0 = dst->src[0];
ggml_tensor *src1 = dst->src[1];
@@ -3660,7 +3691,7 @@ void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
}
}
void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
static void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
dpct::queue_ptr stream = ctx->stream();
if (ctx->optimized_graph) {
return;
@@ -3671,10 +3702,9 @@ void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx)
if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream);
}
}
static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
ggml_sycl_set_main_device(sycl_ctx->device);
static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {
ggml_sycl_set_main_device(sycl_ctx->device);
if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx);
for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -3696,7 +3726,46 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
}
GGML_ASSERT(ok);
}
}
static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
#ifdef GGML_SYCL_GRAPH
if (!g_ggml_sycl_disable_graph) {
if (!sycl_ctx->exec_graph && !dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph)) {
GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
return GGML_STATUS_SUCCESS;
}
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
model_sycl_graph.end_recording();
if (!sycl_ctx->exec_graph) {
auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
sycl_ctx->exec_graph = std::make_unique<
sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
} else {
try {
sycl_ctx->exec_graph->update(model_sycl_graph);
GGML_SYCL_DEBUG("[SYCL-GRAPH] update success\n");
} catch (sycl::exception const & e) {
GGML_SYCL_DEBUG("[SYCL-GRAPH] Exception when updating graph, %s\n", e.what());
auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
sycl_ctx->exec_graph = std::make_unique<
sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
}
}
sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph));
} else
#endif
{
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
}
return GGML_STATUS_SUCCESS;
}
@@ -3761,7 +3830,6 @@ bool ggml_backend_is_sycl(ggml_backend_t backend) {
}
int ggml_backend_sycl_get_device_count() {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
return ggml_sycl_info().device_count;
}
@@ -3851,7 +3919,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return true;
}
return false;
} break;
}
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_NEG:
@@ -3869,7 +3937,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
default:
return false;
}
break;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
@@ -3900,7 +3967,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return false;
}
return true;
} break;
}
case GGML_OP_OUT_PROD:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
case GGML_OP_GET_ROWS:
@@ -3917,7 +3984,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
default:
return false;
}
} break;
}
case GGML_OP_CPY:
{
ggml_type src0_type = op->src[0]->type;
@@ -3968,12 +4035,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return true;
}
return false;
} break;
}
case GGML_OP_CONCAT:
{
ggml_type src0_type = op->src[0]->type;
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
} break;
}
case GGML_OP_DUP:
case GGML_OP_ARGMAX:
case GGML_OP_NONE:
@@ -3997,6 +4064,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return (op->src[0]->type == GGML_TYPE_F32);
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_SCALE:
@@ -4030,6 +4098,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_LEAKY_RELU:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_GATED_LINEAR_ATTN:
return true;
default:

View File

@@ -3017,7 +3017,6 @@ void ggml_sycl_op_mul_mat_q(
break;
default:
GGML_ABORT("fatal error");
break;
}
GGML_UNUSED(src1);

View File

@@ -3,44 +3,42 @@
#include <cassert>
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
const sycl::nd_item<3> &item_ct1) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);
static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
if (row >= nrows) {
return;
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
assert(blocks_per_warp>0);
const int blocks_per_row = ncols / qk;
constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi; // Ensuring blocks_per_warp > 0
// partial sum for each thread
assert(blocks_per_warp > 0);
// partial sum for each thread
float tmp = 0.0f;
const block_q_t * x = (const block_q_t *) vx;
const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy;
for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
i += blocks_per_warp) {
const int ibx = row*blocks_per_row + i; // x block index
for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
const int ibx = row * blocks_per_row + i; // x block index
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
const int iqs =
vdr *
(item_ct1.get_local_id(2) %
(qi / vdr)); // x block quant index when casting the quants to int
for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) {
const int iqs = elem + vdr * (item_ct1.get_local_id(2) %
(qi / vdr)); // x block quant index when casting the quants to int
tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
}
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
if (item_ct1.get_local_id(2) == 0) {
@@ -62,7 +60,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread
@@ -87,7 +85,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result
#pragma unroll
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
@@ -111,7 +109,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread
float tmp = 0.0f;
@@ -135,7 +133,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result
#pragma unroll
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
@@ -159,7 +157,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread
float tmp = 0.0f;
@@ -183,7 +181,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result
#pragma unroll
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
@@ -207,7 +205,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread
float tmp = 0.0f;
@@ -231,7 +229,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result
#pragma unroll
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
@@ -255,7 +253,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread
float tmp = 0.0f;
@@ -279,7 +277,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result
#pragma unroll
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
@@ -303,7 +301,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread
float tmp = 0.0f;
@@ -327,7 +325,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result
#pragma unroll
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
@@ -351,7 +349,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread
float tmp = 0.0f;
@@ -375,7 +373,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result
#pragma unroll
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
@@ -399,7 +397,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread
float tmp = 0.0f;
@@ -423,7 +421,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result
#pragma unroll
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
@@ -448,7 +446,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread
float tmp = 0.0f;
@@ -472,7 +470,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result
#pragma unroll
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
@@ -489,7 +487,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK4_0 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
@@ -497,7 +495,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
@@ -513,7 +511,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK4_1 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
@@ -521,7 +519,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
@@ -537,7 +535,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK5_0 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
@@ -545,7 +543,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
@@ -561,7 +559,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK5_1 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
@@ -569,7 +567,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
@@ -585,7 +583,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK8_0 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
@@ -593,7 +591,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
@@ -609,7 +607,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
@@ -617,7 +615,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
@@ -633,7 +631,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
@@ -641,7 +639,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
@@ -657,7 +655,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
@@ -665,7 +663,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
@@ -681,7 +679,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
@@ -689,7 +687,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
@@ -705,7 +703,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
@@ -713,7 +711,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
@@ -730,13 +728,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
@@ -751,13 +749,13 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
@@ -772,14 +770,14 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
@@ -794,14 +792,14 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
@@ -816,14 +814,14 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
@@ -838,14 +836,14 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
@@ -860,13 +858,13 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
@@ -881,14 +879,14 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK4_NL == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
vx, vy, dst, ncols, nrows, item_ct1);
});
@@ -903,14 +901,14 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
@@ -1005,7 +1003,6 @@ void ggml_sycl_op_mul_mat_vec_q(
break;
default:
GGML_ABORT("fatal error");
break;
}
}
GGML_UNUSED(src1);

View File

@@ -180,6 +180,50 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
}
}
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);
const int tid = item_ct1.get_local_id(2);
const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row * ncols + col];
tmp += xi * xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp, item_ct1);
if (block_size > WARP_SIZE) {
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
/*
DPCT1118:3: SYCL group functions and algorithms must be encountered in
converged control flow. You may need to adjust the code.
*/
item_ct1.barrier(sycl::access::fence_space::local_space);
size_t nreduce = nwarps / WARP_SIZE;
tmp = 0.f;
for (size_t i = 0; i < nreduce; i += 1)
{
tmp += s_sum[lane_id + i * WARP_SIZE];
}
tmp = warp_reduce_sum(tmp, item_ct1);
}
const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[row * ncols + col] = scale * x[row * ncols + col];
}
}
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
@@ -191,7 +235,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE);
});
@@ -214,7 +258,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
@@ -233,7 +277,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(
x, dst, group_size, ne_elements, eps_ct4, item_ct1,
nullptr, WARP_SIZE);
@@ -260,7 +304,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(x, dst, group_size, ne_elements,
eps_ct4, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
@@ -281,7 +325,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE);
});
@@ -303,7 +347,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
@@ -311,6 +355,48 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
}
}
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size);
/*
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh);
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
}
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
ggml_tensor* dst, const float* src0_dd,
const float* src1_dd, float* dst_dd,
@@ -376,3 +462,25 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
(void)dst;
(void)src1_dd;
}
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
(void)src1;
(void)dst;
(void)src1_dd;
}

View File

@@ -32,4 +32,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
float* dst_dd,
const queue_ptr& main_stream);
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream);
#endif // GGML_SYCL_NORM_HPP

View File

@@ -132,7 +132,7 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
nrows_y, scale, max_bias, m0,
m1, n_head_log2, item_ct1,

305
ggml/src/ggml-sycl/wkv.cpp Normal file
View File

@@ -0,0 +1,305 @@
#include <sycl/sycl.hpp>
#include "wkv.hpp"
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
// Helper function for the main kernel
template <int block_size>
static void rwkv_wkv6_f32_kernel(
const int B, const int T, const int C, const int H,
const float* k, const float* v, const float* r,
const float* tf, const float* td, const float* s,
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
const int tid = item_ct1.get_local_id(2);
const int bid = item_ct1.get_group(2);
const int head_size = block_size;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
// Set up shared memory pointers
float* _k = shared_mem;
float* _r = _k + head_size;
float* _tf = _r + head_size;
float* _td = _tf + head_size;
// Local state array
float state[block_size];
// Load initial state
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
// Sync threads before shared memory operations
item_ct1.barrier(sycl::access::fence_space::local_space);
// Load time-mixing parameters
_tf[tid] = tf[head_i * head_size + tid];
item_ct1.barrier(sycl::access::fence_space::local_space);
// Main sequence processing loop
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
t += C) {
item_ct1.barrier(sycl::access::fence_space::local_space);
// Load current timestep data to shared memory
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
item_ct1.barrier(sycl::access::fence_space::local_space);
const float _v = v[t];
float y = 0;
// Process in chunks of 4 for better vectorization
sycl::float4 k4, r4, tf4, td4, s4;
#pragma unroll
for (int j = 0; j < head_size; j += 4) {
// Load data in vec4 chunks
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
// Compute key-value product
sycl::float4 kv4 = k4 * _v;
// Accumulate weighted sum
y += sycl::dot(r4, tf4 * kv4 + s4);
// Update state
s4 = s4 * td4 + kv4;
// Store updated state
state[j] = s4.x();
state[j+1] = s4.y();
state[j+2] = s4.z();
state[j+3] = s4.w();
}
dst[t] = y;
}
// Save final state
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}
template <int block_size>
static void rwkv_wkv7_f32_kernel(
const int B, const int T, const int C, const int H,
const float* r, const float* w, const float* k, const float* v,
const float* a, const float* b, const float* s,
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
const int tid = item_ct1.get_local_id(2);
const int bid = item_ct1.get_group(2);
const int head_size = block_size;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float* _r = shared_mem;
float* _w = _r + head_size;
float* _k = _w + head_size;
float* _a = _k + head_size;
float* _b = _a + head_size;
float state[block_size];
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
}
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
t += C) {
item_ct1.barrier(sycl::access::fence_space::local_space);
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
item_ct1.barrier(sycl::access::fence_space::local_space);
const float _v = v[t];
float y = 0, sa = 0;
sycl::float4 a4, s4;
#pragma unroll
for (int j = 0; j < head_size; j += 4) {
a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
sa += sycl::dot(a4, s4);
}
sycl::float4 r4, w4, k4, b4;
#pragma unroll
for (int j = 0; j < head_size; j += 4) {
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
sycl::float4 kv4 = k4 * _v;
s4 = s4 * w4 + kv4 + sa * b4;
y += sycl::dot(r4, s4);
state[j] = s4.x();
state[j+1] = s4.y();
state[j+2] = s4.z();
state[j+3] = s4.w();
}
dst[t] = y;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
}
}
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
const float* k_d = (const float*)dst->src[0]->data;
const float* v_d = (const float*)dst->src[1]->data;
const float* r_d = (const float*)dst->src[2]->data;
const float* tf_d = (const float*)dst->src[3]->data;
const float* td_d = (const float*)dst->src[4]->data;
const float* s_d = (const float*)dst->src[5]->data;
float* dst_d = (float*)dst->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
dpct::queue_ptr stream = ctx.stream();
// Calculate execution configuration
const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td
sycl::range<3> block_dims(1, 1, C / H);
sycl::range<3> grid_dims(1, 1, B * H);
// Submit kernel
if (C / H == WKV_BLOCK_SIZE) {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
} else {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
}
GGML_UNUSED(src0);
GGML_UNUSED(src1);
}
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
const float* r_d = (const float*)dst->src[0]->data;
const float* w_d = (const float*)dst->src[1]->data;
const float* k_d = (const float*)dst->src[2]->data;
const float* v_d = (const float*)dst->src[3]->data;
const float* a_d = (const float*)dst->src[4]->data;
const float* b_d = (const float*)dst->src[5]->data;
const float* s_d = (const float*)dst->src[6]->data;
float* dst_d = (float*)dst->data;
const int64_t B = dst->src[6]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);
dpct::queue_ptr stream = ctx.stream();
// Calculate execution configuration
const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b
sycl::range<3> block_dims(1, 1, C / H);
sycl::range<3> grid_dims(1, 1, B * H);
// Submit kernel
if (C / H == WKV_BLOCK_SIZE) {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
} else {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
}
GGML_UNUSED(src0);
GGML_UNUSED(src1);
}

View File

@@ -0,0 +1,10 @@
#ifndef GGML_SYCL_WKV_HPP
#define GGML_SYCL_WKV_HPP
#include "common.hpp"
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_WKV_HPP

View File

@@ -1,143 +0,0 @@
#include <sycl/sycl.hpp>
#include "wkv6.hpp"
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
// Helper function for the main kernel
static void rwkv_wkv_f32_kernel(
const int B, const int T, const int C, const int H,
const float* k, const float* v, const float* r,
const float* tf, const float* td, const float* s,
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
const int tid = item_ct1.get_local_id(2);
const int bid = item_ct1.get_group(2);
const int head_size = WKV_BLOCK_SIZE;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
// Set up shared memory pointers
float* _k = shared_mem;
float* _r = _k + head_size;
float* _tf = _r + head_size;
float* _td = _tf + head_size;
// Local state array
float state[WKV_BLOCK_SIZE];
// Load initial state
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
// Sync threads before shared memory operations
item_ct1.barrier(sycl::access::fence_space::local_space);
// Load time-mixing parameters
_tf[tid] = tf[head_i * head_size + tid];
item_ct1.barrier(sycl::access::fence_space::local_space);
// Main sequence processing loop
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
t += C) {
item_ct1.barrier(sycl::access::fence_space::local_space);
// Load current timestep data to shared memory
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
item_ct1.barrier(sycl::access::fence_space::local_space);
const float _v = v[t];
float y = 0;
// Process in chunks of 4 for better vectorization
sycl::float4 k4, r4, tf4, td4, s4;
#pragma unroll
for (int j = 0; j < head_size; j += 4) {
// Load data in vec4 chunks
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
// Compute key-value product
sycl::float4 kv4 = k4 * _v;
// Accumulate weighted sum
y += sycl::dot(r4, tf4 * kv4 + s4);
// Update state
s4 = s4 * td4 + kv4;
// Store updated state
state[j] = s4.x();
state[j+1] = s4.y();
state[j+2] = s4.z();
state[j+3] = s4.w();
}
dst[t] = y;
}
// Save final state
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
const float* k_d = (const float*)dst->src[0]->data;
const float* v_d = (const float*)dst->src[1]->data;
const float* r_d = (const float*)dst->src[2]->data;
const float* tf_d = (const float*)dst->src[3]->data;
const float* td_d = (const float*)dst->src[4]->data;
const float* s_d = (const float*)dst->src[5]->data;
float* dst_d = (float*)dst->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
dpct::queue_ptr stream = ctx.stream();
// Calculate execution configuration
const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
sycl::range<3> block_dims(1, 1, C / H);
sycl::range<3> grid_dims(1, 1, B * H);
// Submit kernel
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv_f32_kernel(
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
GGML_UNUSED(src0);
GGML_UNUSED(src1);
}

View File

@@ -1,9 +0,0 @@
#ifndef GGML_SYCL_WKV6_HPP
#define GGML_SYCL_WKV6_HPP
#include "common.hpp"
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_WKV6_HPP

View File

@@ -29,6 +29,7 @@
#include "ggml-vulkan-shaders.hpp"
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
#define VK_VENDOR_ID_AMD 0x1002
@@ -149,6 +150,66 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf);
static constexpr uint32_t mul_mat_vec_max_cols = 8;
enum vk_device_architecture {
OTHER,
AMD_GCN,
AMD_RDNA1,
AMD_RDNA2,
AMD_RDNA3,
};
static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
vk::PhysicalDeviceProperties props = device.getProperties();
if (props.vendorID == VK_VENDOR_ID_AMD) {
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
bool amd_shader_core_properties = false;
bool integer_dot_product = false;
bool subgroup_size_control = false;
for (const auto& properties : ext_props) {
if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) {
amd_shader_core_properties = true;
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
integer_dot_product = true;
} else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
subgroup_size_control = true;
}
}
if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
return vk_device_architecture::OTHER;
}
vk::PhysicalDeviceProperties2 props2;
vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
props2.pNext = &shader_core_props_amd;
shader_core_props_amd.pNext = &integer_dot_props;
integer_dot_props.pNext = &subgroup_size_control_props;
device.getProperties2(&props2);
if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) {
return vk_device_architecture::AMD_GCN;
}
if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) {
// RDNA
if (shader_core_props_amd.wavefrontsPerSimd == 20) {
return vk_device_architecture::AMD_RDNA1;
}
if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) {
return vk_device_architecture::AMD_RDNA3;
}
return vk_device_architecture::AMD_RDNA2;
}
}
return vk_device_architecture::OTHER;
}
struct vk_device_struct {
std::mutex mutex;
@@ -161,6 +222,7 @@ struct vk_device_struct {
bool pipeline_robustness;
vk::Device device;
uint32_t vendor_id;
vk_device_architecture architecture;
vk_queue compute_queue;
vk_queue transfer_queue;
bool single_queue;
@@ -242,6 +304,7 @@ struct vk_device_struct {
vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
vk_pipeline pipeline_rms_norm_back_f32;
vk_pipeline pipeline_l2_norm_f32;
vk_pipeline pipeline_gelu_f32;
vk_pipeline pipeline_gelu_quick_f32;
vk_pipeline pipeline_silu_f32;
@@ -266,6 +329,7 @@ struct vk_device_struct {
vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
@@ -368,6 +432,7 @@ struct vk_mat_mat_push_constants {
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t k_split;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t padded_N;
};
struct vk_mat_vec_push_constants {
uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -380,6 +445,7 @@ struct vk_mat_mat_id_push_constants {
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
uint32_t padded_N;
};
struct vk_mat_vec_id_push_constants {
uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -565,6 +631,13 @@ struct vk_op_rwkv_wkv6_push_constants {
uint32_t H;
};
struct vk_op_rwkv_wkv7_push_constants {
uint32_t B;
uint32_t T;
uint32_t C;
uint32_t H;
};
// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1445,6 +1518,73 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
return supported;
}
struct GpuPipelineConfig {
// GPU architecture identifier.
// Example: vk_device_architecture::AMD_GCN
vk_device_architecture arch;
// Mapping of pipeline names to their specific subgroup sizes.
// Example: {"soft_max_f32", 64}
std::unordered_map<std::string, uint32_t> pipelines;
// Default subgroup size for this GPU.
// Defaults to 0 if not explicitly provided.
uint32_t default_subgroup_size = 0;
};
// Pipeline configuration for RDNA1 GPUs.
static const std::unordered_map<std::string, uint32_t> rdna1_pipelines = {
{"soft_max", 64}, {"im2col", 64},
{"argmax", 64}, {"mul_mat_vec", 64},
{"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32}
};
// Pipeline configuration for RDNA2 GPUs.
static const std::unordered_map<std::string, uint32_t> rdna2_pipelines = {
{"soft_max", 64}, {"im2col", 64},
};
static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32;
// Define configurations for different GPUs.
static std::vector<GpuPipelineConfig> gpu_pipeline_configs = {
{
vk_device_architecture::AMD_RDNA1,
{
rdna1_pipelines,
},
RDNA_DEFAULT_SUBGROUP_SIZE
},
{
vk_device_architecture::AMD_RDNA2,
{
rdna2_pipelines,
},
RDNA_DEFAULT_SUBGROUP_SIZE
},
};
static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) {
for (const auto &config : gpu_pipeline_configs) {
if (config.arch == arch) {
auto pipIt = config.pipelines.find(pipeline_name);
if (pipIt != config.pipelines.end()) {
return pipIt->second;
}
std::vector<std::pair<std::string, uint32_t>> sorted_pipelines(config.pipelines.begin(), config.pipelines.end());
std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),
[](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
for (const auto &entry : sorted_pipelines) {
if (pipeline_name.find(entry.first) != std::string::npos) {
return entry.second;
}
}
return config.default_subgroup_size;
}
}
return 0; // If no matching configuration is found
}
static void ggml_vk_load_shaders(vk_device& device) {
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
@@ -1466,36 +1606,36 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t l_align, m_align, s_align;
if (device->coopmat2) {
// spec constants and tile sizes for non-quant matmul/matmul_id
l_warptile = { 256, 128, 256, 64 };
m_warptile = { 256, 128, 128, 64 };
s_warptile = { 128, 64, 64, 64 };
l_warptile = { 256, 128, 256, 64, 1 };
m_warptile = { 256, 128, 128, 64, 0 };
s_warptile = { 128, 64, 64, 64, 0 };
l_wg_denoms = {128, 256, 1 };
m_wg_denoms = {128, 128, 1 };
s_wg_denoms = { 64, 64, 1 };
// spec constants and tile sizes for quant matmul (non-Qi_K)
l_warptile_mmq = { 256, 128, 256, 64 };
m_warptile_mmq = { 256, 128, 128, 64 };
s_warptile_mmq = { 256, 128, 128, 64 };
l_warptile_mmq = { 256, 128, 256, 64, 1 };
m_warptile_mmq = { 256, 128, 128, 64, 1 };
s_warptile_mmq = { 256, 32, 64, 128, 0 };
l_mmq_wg_denoms = { 128, 256, 1 };
m_mmq_wg_denoms = { 128, 128, 1 };
s_mmq_wg_denoms = { 128, 128, 1 };
s_mmq_wg_denoms = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul (Qi_K)
l_warptile_mmq_k = { 256, 128, 512, 16 };
m_warptile_mmq_k = { 256, 128, 256, 16 };
s_warptile_mmq_k = { 256, 32, 128, 64 };
l_mmq_wg_denoms_k = { 128, 512, 1 };
m_mmq_wg_denoms_k = { 128, 256, 1 };
s_mmq_wg_denoms_k = { 32, 128, 1 };
l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
l_mmq_wg_denoms_k = { 64, 128, 1 };
m_mmq_wg_denoms_k = { 32, 64, 1 };
s_mmq_wg_denoms_k = { 32, 32, 1 };
// spec constants and tile sizes for quant matmul_id
l_warptile_mmqid = { 256, 128, 128, 16 };
m_warptile_mmqid = { 256, 128, 64, 16 };
s_warptile_mmqid = { 256, 64, 64, 16 };
l_mmqid_wg_denoms = { 128, 128, 1 };
l_warptile_mmqid = { 256, 128, 64, 16, 0 };
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
l_mmqid_wg_denoms = { 128, 64, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 64, 64, 1 };
s_mmqid_wg_denoms = { 128, 64, 1 };
l_align = 128;
m_align = 64;
@@ -1571,6 +1711,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
if (!require_full_subgroups && required_subgroup_size == 0) {
required_subgroup_size = get_subgroup_size(name, device->architecture);
}
if (!pipeline) {
pipeline = std::make_shared<vk_pipeline_struct>();
pipeline->name = name;
@@ -2128,6 +2272,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2239,6 +2384,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
for (auto &c : compiles) {
@@ -2247,7 +2394,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
device->need_compiles = false;
}
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
static vk_device ggml_vk_get_device(size_t idx) {
VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
@@ -2276,6 +2423,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->physical_device = physical_devices[dev_num];
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
device->architecture = get_device_architecture(device->physical_device);
const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
@@ -2288,7 +2437,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
bool coopmat2_support = false;
device->coopmat_support = false;
// Check if maintenance4 is supported
for (const auto& properties : ext_props) {
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
maintenance4_support = true;
@@ -2376,13 +2524,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
#if defined(_WIN32)
} else if (device->vendor_id == VK_VENDOR_ID_NVIDIA) {
} else {
// Limit batching of allocations to 1GB by default to avoid fragmentation issues
device->suballocation_block_size = 1024*1024*1024;
#endif
} else {
device->suballocation_block_size = device->max_memory_allocation_size;
}
device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
@@ -2401,7 +2545,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) {
if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) {
device->coopmat_support = false;
}
@@ -2779,7 +2923,10 @@ static void ggml_vk_print_gpu_info(size_t idx) {
subgroup_props.pNext = &driver_props;
physical_device.getProperties2(&props2);
const size_t subgroup_size = subgroup_props.subgroupSize;
vk_device_architecture arch = get_device_architecture(physical_device);
uint32_t default_subgroup_size = get_subgroup_size("", arch);
const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
bool fp16_storage = false;
@@ -2805,7 +2952,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
}
}
if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) {
const vk_device_architecture device_architecture = get_device_architecture(physical_device);
if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
coopmat_support = false;
}
@@ -3850,10 +3999,14 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
if (ctx->device->coopmat2) {
if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
// Use large shader when the N dimension is greater than the medium shader's tile size
uint32_t crossover_large = mmp->m->wg_denoms[1];
if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
return aligned ? mmp->a_l : mmp->l;
}
if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
// Use medium shader when the N dimension is greater than the small shader's tile size
uint32_t crossover_medium = mmp->s->wg_denoms[1];
if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_s : mmp->s;
@@ -3878,18 +4031,19 @@ static void ggml_vk_matmul(
vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) {
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
ggml_vk_sync_buffers(subctx);
if (split_k == 1) {
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
return;
}
GGML_ASSERT(batch_stride_d == m * n);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 };
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
ggml_vk_sync_buffers(subctx);
@@ -3898,13 +4052,17 @@ static void ggml_vk_matmul(
}
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
if (ctx->device->coopmat2) {
if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
// Use large shader when the N dimension is greater than the medium shader's tile size
uint32_t crossover_large = mmp->m->wg_denoms[1];
if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
return aligned ? mmp->a_l : mmp->l;
}
if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
// Use medium shader when the N dimension is greater than the small shader's tile size
uint32_t crossover_medium = mmp->s->wg_denoms[1];
if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_s : mmp->s;
@@ -3929,14 +4087,15 @@ static void ggml_vk_matmul_id(
vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
"m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
"batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
"n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
ggml_vk_sync_buffers(subctx);
const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
nei0, nei1, nbi1, ne11 };
nei0, nei1, nbi1, ne11, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
}
@@ -4098,15 +4257,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const int x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
const int x_ne = ne01 * ne00;
const int y_ne = padded_n * ne10;
const int d_ne = ne11 * ne01;
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
@@ -4229,7 +4390,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
{ d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
ne01, ne11, ne10,
ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
split_k, ne12*ne13, ne02, ne12, r2, r3
split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
); // NOLINT
}
@@ -4680,15 +4841,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const uint64_t x_ne = ne01 * ne00;
const uint64_t y_ne = ne11 * ne10;
const uint64_t d_ne = ne21 * ne20;
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
const uint64_t x_ne = ne01 * ne00;
const uint64_t y_ne = padded_n * ne10;
const uint64_t d_ne = ne21 * ne20;
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@@ -4807,7 +4970,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
{ d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
ne01, ne21, ne10, ne10, ne10, ne01,
stride_batch_x, stride_batch_y, ne20*ne21,
n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11
n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
); // NOLINT
}
@@ -5318,6 +5481,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rms_norm_back_f32;
}
return nullptr;
case GGML_OP_L2_NORM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_l2_norm_f32;
}
return nullptr;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(dst)) {
case GGML_UNARY_OP_SILU:
@@ -5457,6 +5625,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rwkv_wkv6_f32;
}
return nullptr;
case GGML_OP_RWKV_WKV7:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
case GGML_OP_OPT_STEP_ADAMW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_opt_step_adamw_f32;
@@ -5704,6 +5877,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_SUM_ROWS:
@@ -5953,23 +6127,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
}, dryrun);
}
static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
const ggml_tensor * k = dst->src[0];
const ggml_tensor * v = dst->src[1];
const ggml_tensor * r = dst->src[2];
const ggml_tensor * tf = dst->src[3];
const ggml_tensor * td = dst->src[4];
const ggml_tensor * state = dst->src[5];
static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
GGML_ASSERT(version == 6 || version == 7);
int num_srcs = version == 6 ? 6 : 7;
for (int i = 0; i < num_srcs; i++) {
GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
}
GGML_ASSERT(!ggml_is_quantized(k->type));
GGML_ASSERT(!ggml_is_quantized(v->type));
GGML_ASSERT(!ggml_is_quantized(r->type));
GGML_ASSERT(!ggml_is_quantized(tf->type));
GGML_ASSERT(!ggml_is_quantized(td->type));
GGML_ASSERT(!ggml_is_quantized(state->type));
GGML_ASSERT(dst->buffer != nullptr);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
GGML_ASSERT(pipeline != nullptr);
if (dryrun) {
@@ -5978,89 +6146,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
}
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
for (int i = 0; i < num_srcs; i++) {
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
}
ggml_vk_sync_buffers(subctx);
vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0;
bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
for (int i = 0; i < num_srcs; i++) {
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
srcs_uma[i] = d_srcs[i] != nullptr;
}
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
K_uma = d_K != nullptr;
V_uma = d_V != nullptr;
R_uma = d_R != nullptr;
TF_uma = d_TF != nullptr;
TD_uma = d_TD != nullptr;
STATE_uma = d_State != nullptr;
DST_uma = d_D != nullptr;
dst_uma = d_D != nullptr;
}
if (!K_uma) {
d_K = k_buf_ctx->dev_buffer;
k_offset = vk_tensor_offset(k) + k->view_offs;
uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
for (int i = 0; i < num_srcs; i++) {
src_sizes[i] = ggml_nbytes(dst->src[i]);
if (!srcs_uma[i]) {
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
}
}
if (!V_uma) {
d_V = v_buf_ctx->dev_buffer;
v_offset = vk_tensor_offset(v) + v->view_offs;
}
if (!R_uma) {
d_R = r_buf_ctx->dev_buffer;
r_offset = vk_tensor_offset(r) + r->view_offs;
}
if (!TF_uma) {
d_TF = tf_buf_ctx->dev_buffer;
tf_offset = vk_tensor_offset(tf) + tf->view_offs;
}
if (!TD_uma) {
d_TD = td_buf_ctx->dev_buffer;
td_offset = vk_tensor_offset(td) + td->view_offs;
}
if (!STATE_uma) {
d_State = state_buf_ctx->dev_buffer;
state_offset = vk_tensor_offset(state) + state->view_offs;
}
if (!DST_uma) {
const uint64_t dst_size = ggml_nbytes(dst);
if (!dst_uma) {
d_D = dst_buf_ctx->dev_buffer;
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
}
const uint64_t k_size = ggml_nbytes(k);
const uint64_t v_size = ggml_nbytes(v);
const uint64_t r_size = ggml_nbytes(r);
const uint64_t tf_size = ggml_nbytes(tf);
const uint64_t td_size = ggml_nbytes(td);
const uint64_t state_size = ggml_nbytes(state);
const uint64_t dst_size = ggml_nbytes(dst);
std::array<uint32_t, 3> elements = {
(uint32_t)(pc.B * pc.H),
1,
1
};
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_K, k_offset, k_size },
vk_subbuffer{ d_V, v_offset, v_size },
vk_subbuffer{ d_R, r_offset, r_size },
vk_subbuffer{ d_TF, tf_offset, tf_size },
vk_subbuffer{ d_TD, td_offset, td_size },
vk_subbuffer{ d_State, state_offset, state_size },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
if (version == 6) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
} else if (version == 7) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
} else {
// shouldn't happen
GGML_ASSERT(false);
}
}
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
@@ -6069,7 +6221,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[5]->ne[1];
ggml_vk_op_f32_rwkv6(
ggml_vk_op_f32_wkv(
ctx, subctx, dst,
{
(uint32_t)n_seqs,
@@ -6077,6 +6229,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
(uint32_t)n_embed,
(uint32_t)n_heads,
},
6,
dryrun
);
}
static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[6]->ne[1];
ggml_vk_op_f32_wkv(
ctx, subctx, dst,
{
(uint32_t)n_seqs,
(uint32_t)seq_length,
(uint32_t)n_embed,
(uint32_t)n_heads,
},
7,
dryrun
);
}
@@ -6378,6 +6550,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
@@ -6767,7 +6944,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1
split_k, batch, batch, batch, 1, 1, n
);
}
ggml_vk_ctx_end(subctx);
@@ -7112,7 +7289,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1
split_k, batch, batch, batch, 1, 1, n
);
}
ggml_vk_ctx_end(subctx);
@@ -7373,6 +7550,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
@@ -7389,6 +7567,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_OPT_STEP_ADAMW:
@@ -7435,6 +7614,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_UNARY:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
@@ -7552,6 +7732,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_RMS_NORM_BACK:
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_L2_NORM:
ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(node)) {
@@ -7642,6 +7826,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
break;
case GGML_OP_RWKV_WKV7:
ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
break;
case GGML_OP_OPT_STEP_ADAMW:
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
@@ -7715,6 +7904,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
@@ -7734,6 +7924,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
@@ -8651,6 +8842,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_ADD:
case GGML_OP_SUB:
@@ -8680,6 +8872,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_ADAMW:
return true;
@@ -8826,7 +9019,7 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
UNUSED(instance_extensions);
}
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
switch (props.vendorID) {
case VK_VENDOR_ID_INTEL:
// Intel drivers don't support coopmat properly yet
@@ -8834,10 +9027,7 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
case VK_VENDOR_ID_AMD:
if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
// Workaround for AMD proprietary driver reporting support on all GPUs
const std::string name = props.deviceName;
return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs
name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs
name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs
return arch == vk_device_architecture::AMD_RDNA3;
}
return true;
default:
@@ -9067,6 +9257,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
} else if (tensor->op == GGML_OP_SILU_BACK) {
tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_L2_NORM) {
const float eps = ((float *) tensor->op_params)[0];
tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
} else if (tensor->op == GGML_OP_SOFT_MAX) {
if (src1 != nullptr) {
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
@@ -9186,6 +9379,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
} else if (tensor->op == GGML_OP_RWKV_WKV7) {
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
src_clone[0]->flags = src0->flags;
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],

View File

@@ -1,8 +1,4 @@
find_package (Threads REQUIRED)
find_program(GLSLC_EXECUTABLE glslc)
if(NOT GLSLC_EXECUTABLE)
message(FATAL_ERROR "glslc not found.")
endif()
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)

View File

@@ -178,7 +178,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
uvec4 v = bl128.block.q4k[0];
const f16vec2 loadd = unpackFloat2x16(v.x);
const vec2 loadd = vec2(unpackFloat2x16(v.x));
uint32_t sc;
uint32_t mbyte;
@@ -199,15 +199,15 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
sc &= 0x3F;
mbyte &= 0x3F;
const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);
const float d = loadd.x * float(sc);
const float m = loadd.y * float(mbyte);
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
float16_t ret = d * float16_t(qs) - m;
float ret = d * float(qs) - m;
return ret;
return float16_t(ret);
}
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {

View File

@@ -0,0 +1,41 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE sum[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
sum[tid] += xi * xi;
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
sum[tid] += sum[tid + s];
}
barrier();
}
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
}
}

View File

@@ -777,7 +777,7 @@ void main() {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;

View File

@@ -23,6 +23,10 @@ layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
layout (constant_id = 4) const bool enable_smaller_matrices = false;
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
layout (push_constant) uniform parameter
{
uint M;
@@ -48,6 +52,8 @@ layout (push_constant) uniform parameter
uint broadcast2;
uint broadcast3;
#endif
// N dimension for the B matrix can be >= p.N
uint padded_N;
} p;
@@ -166,15 +172,13 @@ void main() {
const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
#ifdef MUL_MAT_ID
uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
uint pos_b = 0;
#else
uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
uint pos_b = batch_idx * p.batch_stride_b;
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif
uint stride_a = p.stride_a / QUANT_K;
@@ -195,6 +199,7 @@ void main() {
tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
#if QUANT_K > 1
tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
@@ -202,18 +207,19 @@ void main() {
#endif
// Use end_k rather than p.K as the dimension because that's what
// we need to bound check against when using split_k
// we need to bound check against when using split_k.
// Bounds check B against padded_N, but bounds check D against N.
tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k);
tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k);
tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
#if !defined(MUL_MAT_ID)
// Detect a fast path where all loads are entirely in bounds and no clamping is required
if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
#if QUANT_K == 1
(stride_a % 8) == 0 &&
#endif
@@ -229,16 +235,54 @@ void main() {
tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
uint k_iters = (end_k - start_k + BK - 1) / BK;
if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
sum = coopMatMulAdd(mat_a, mat_b, sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
return;
} else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
return;
} else {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
return;
}
} else
#endif // !defined(MUL_MAT_ID)
@@ -251,6 +295,9 @@ void main() {
tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
[[dont_unroll]]
for (uint block_k = start_k; block_k < end_k; block_k += BK) {
@@ -263,7 +310,7 @@ void main() {
#ifdef MUL_MAT_ID
bool unclampedB = true;
#else
bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0;
bool unclampedB = (ic + 1) * BN <= p.padded_N && block_k + BK <= end_k && (block_k % 8) == 0;
#endif
if (unclampedA && unclampedB) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
@@ -293,19 +340,16 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
}
}
// Convert from ACC_TYPE to D_TYPE
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
// Convert from ACC_TYPE to D_TYPE
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
#ifdef MUL_MAT_ID
// Call callback to store each element, remapping row through shared memory
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
// Call callback to store each element, remapping row through shared memory
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
#else
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
#endif
}
}

View File

@@ -434,6 +434,7 @@ void process_shaders() {
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
@@ -528,6 +529,8 @@ void process_shaders() {
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
for (auto &c : compiles) {

Some files were not shown because too many files have changed in this diff Show More