Compare commits

...

3 Commits

Author SHA1 Message Date
Georgi Gerganov
78d70223c3 metal : use FA-vec kernel up to batch size 20
ggml-ci
2025-05-13 10:38:06 +03:00
Georgi Gerganov
fdfc7de7fc metal : optimize multi-sequence FA vec kernel
ggml-ci
2025-05-13 08:03:27 +03:00
Georgi Gerganov
f078c79865 batched-bench : fix pp batch contents 2025-05-13 07:55:30 +03:00
3 changed files with 8 additions and 3 deletions

View File

@@ -4358,7 +4358,7 @@ static bool ggml_metal_encode_node(
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
// for now avoiding mainly to keep the number of templates/kernels a bit lower
// these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
-if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
+if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
switch (src1->type) {
case GGML_TYPE_F16:
{

View File

@@ -3887,6 +3887,11 @@ kernel void kernel_flash_attn_ext_vec(
sm[tiisg] = pm[ic + tiisg];
}
+// skip -INF blocks
+if (simd_max(sm[tiisg]) == -INFINITY) {
+    continue;
+}
// Q*K^T
{
// each simdgroup processes 1 query and NE (NW/NL) head elements

View File

@@ -123,8 +123,8 @@ int main(int argc, char ** argv) {
common_batch_clear(batch);
-for (int i = 0; i < pp; ++i) {
-    for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
+for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
+    for (int i = 0; i < pp; ++i) {
common_batch_add(batch, 0, i, { j }, false);
}
}