Compare commits

...

9 Commits
b7835 ... b7844

Author SHA1 Message Date
Gaurav Garg
a83c73a18a [CUDA] Reduce CPU-side stalls due to the CUDA command buffer being full (#19042)
* [CUDA] Reduce CPU-side stalls due to the CUDA command buffer being full

With pipeline parallelism, during prompt processing, the CPU-side CUDA command buffer gets full, stalling the CPU. Due to this, enough work doesn't get submitted to the GPU, causing bubbles in the GPU timeline.
Fix this by setting the CUDA environment variable CUDA_SCALE_LAUNCH_QUEUES to 4x to increase the command buffer size.

* Set the env variable in the CUDA backend registry allocation

* Add link to PR in code comment

* Remove warning logs and update documentation
2026-01-27 08:52:44 +02:00
Daniel Bevenius
fc3cdf32ce common : clarify HTTPS build options in error message (#19103)
* common : clarify HTTPS build options in error message

This commit updates the https error message to provide clearer
instructions for users who encounter the "HTTPS is not supported" error.

The motivation for this is that it might not be clear to users that only
one of these options are needed to enable HTTPS support.
The LLAMA_OPENSSL option is also added to the message to cover all
possible build configurations.

* clarify that OpenSSL is the default for HTTPS support
2026-01-27 06:16:00 +01:00
shalinib-ibm
7afdfc9b84 ggml-cpu: Enable FP16 MMA kernels on PPC (#19060) 2026-01-27 11:52:34 +08:00
lhez
94eeb5967c opencl: add flattened q6_K mv (#19054)
* opencl: flatten `q6_K` and add `kernel_mul_mv_q6_K_f32_flat`

* opencl: clean up

* opencl: refactor q6_K mv - put loop body in `block_q_6_K_dot_y_flat`

* opencl: tweak the workgroup size a bit

* opencl: output 4 values per subgroup for `kernel_mul_mv_q6_K_f32_flat`

* opencl: proper alignment for q6_K

* opencl: boundary handling for flattened q6_K mv

* opencl: rename q6_K mv kernel file

* opencl: put flattened q6_K mv in its own file

* opencl: use lower k in file name

* opencl: use K in variable names
2026-01-26 19:36:24 -08:00
Johannes Gäßler
b0311c16d2 CUDA: fix padding of GQA to power of 2 in FA (#19115) 2026-01-26 23:24:58 +01:00
Georgi Gerganov
8f80d1b254 graph : fix nkvo offload with FA (#19105) 2026-01-26 20:18:34 +02:00
Sigbjørn Skjæret
142cbe2ac6 ci : use new 1vCPU runner for lightweight jobs (#19107)
* use new 1vCPU runner for lightweight jobs

* pyright is too heavy, look into ty some day

use new pip-install input
2026-01-26 15:22:49 +01:00
Georgi Gerganov
56f3ebf38e model : add correct type for GLM 4.7 Flash (#19106) 2026-01-26 11:24:30 +02:00
Johannes Gäßler
0c21677e43 CUDA: faster FA for GQA > 1 but not power of 2 (#19092) 2026-01-25 21:19:47 +01:00
30 changed files with 749 additions and 115 deletions

View File

@@ -19,7 +19,7 @@ on:
jobs:
check-vendor:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
steps:
- name: Checkout

View File

@@ -10,7 +10,7 @@ permissions:
jobs:
close-issues:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
permissions:
issues: write
pull-requests: write

View File

@@ -20,7 +20,7 @@ concurrency:
jobs:
editorconfig:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
steps:
- uses: actions/checkout@v6
- uses: editorconfig-checker/action-editorconfig-checker@v2

View File

@@ -21,7 +21,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
steps:
- uses: actions/checkout@v6

View File

@@ -7,7 +7,7 @@ jobs:
permissions:
contents: read
pull-requests: write
runs-on: ubuntu-latest
runs-on: ubuntu-slim
steps:
- uses: actions/checkout@v6
with:

View File

@@ -12,7 +12,7 @@ on:
jobs:
pre-tokenizer-hashes:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
steps:
- name: Checkout repository

View File

@@ -20,7 +20,7 @@ concurrency:
jobs:
python-check-requirements:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
name: check-requirements
steps:
- name: Check out source repository

View File

@@ -15,7 +15,7 @@ concurrency:
jobs:
flake8-lint:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
name: Lint
steps:
- name: Check out source repository

View File

@@ -29,9 +29,7 @@ jobs:
uses: actions/setup-python@v6
with:
python-version: "3.11"
- name: Install Python dependencies
# TODO: use a venv
run: pip install -r requirements/requirements-all.txt
pip-install: -r requirements/requirements-all.txt
- name: Type-check with Pyright
uses: jakebailey/pyright-action@v2
with:

View File

@@ -14,7 +14,7 @@ on:
jobs:
update-ops-docs:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
steps:
- name: Checkout repository

View File

@@ -8,7 +8,7 @@ on:
jobs:
update:
name: Update Winget Package
runs-on: ubuntu-latest
runs-on: ubuntu-slim
if: github.repository_owner == 'ggml-org'
steps:

View File

@@ -60,10 +60,10 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
#ifndef CPPHTTPLIB_OPENSSL_SUPPORT
if (parts.scheme == "https") {
throw std::runtime_error(
"HTTPS is not supported. Please rebuild with:\n"
"HTTPS is not supported. Please rebuild with one of:\n"
" -DLLAMA_BUILD_BORINGSSL=ON\n"
" -DLLAMA_BUILD_LIBRESSL=ON\n"
"or ensure dev files of an OpenSSL-compatible library are available when building."
" -DLLAMA_OPENSSL=ON (default, requires OpenSSL dev files installed)"
);
}
#endif

View File

@@ -248,6 +248,14 @@ You may set the [cuda environmental variables](https://docs.nvidia.com/cuda/cuda
CUDA_VISIBLE_DEVICES="-0" ./build/bin/llama-server --model /srv/models/llama.gguf
```
#### CUDA_SCALE_LAUNCH_QUEUES
The environment variable [`CUDA_SCALE_LAUNCH_QUEUES`](https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/environment-variables.html#cuda-scale-launch-queues) controls the size of CUDA's command buffer, which determines how many GPU operations can be queued before the CPU must wait for the GPU to catch up. A larger buffer reduces CPU-side stalls and allows more work to be queued on a GPU.
**Default behavior:** llama.cpp automatically sets `CUDA_SCALE_LAUNCH_QUEUES=4x`, which increases the CUDA command buffer to 4 times its default size. This optimization is particularly beneficial for **Multi-GPU setups with pipeline parallelism**, where it significantly improves prompt processing throughput by allowing more operations to be enqueued across GPUs.
See PR [#19042](https://github.com/ggml-org/llama.cpp/pull/19042) for performance benchmarks and technical details.
### Unified Memory
The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted. In Windows this setting is available in the NVIDIA control panel as `System Memory Fallback`.

View File

@@ -1797,10 +1797,27 @@ class tinyBLAS_Q0_AVX {
} \
} \
template<typename T>
struct mma_instr;
template<>
struct mma_instr<ggml_bf16_t> {
static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
__builtin_mma_xvbf16ger2pp(acc, a, b);
}
};
template<>
struct mma_instr<ggml_fp16_t> {
static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
__builtin_mma_xvf16ger2pp(acc, a, b);
}
};
template <typename TA, typename TB, typename TC>
class tinyBLAS_BF16_PPC {
class tinyBLAS_HP16_PPC {
public:
tinyBLAS_BF16_PPC(int64_t k,
tinyBLAS_HP16_PPC(int64_t k,
const TA *A, int64_t lda,
const TB *B, int64_t ldb,
TC *C, int64_t ldc,
@@ -2118,8 +2135,8 @@ class tinyBLAS_BF16_PPC {
packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
for (int x = 0; x < 4; x++) {
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
}
}
SAVE_ACC(&acc_0, ii, jj);
@@ -2135,8 +2152,8 @@ class tinyBLAS_BF16_PPC {
packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
for (int x = 0; x < 4; x++) {
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
}
}
SAVE_ACC(&acc_0, ii, jj);
@@ -2155,10 +2172,10 @@ class tinyBLAS_BF16_PPC {
packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
for (int x = 0; x < 4; x++) {
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
__builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
__builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
}
}
@@ -2189,7 +2206,7 @@ class tinyBLAS_BF16_PPC {
packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
for (int x = 0; x<2; x++) {
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
}
}
__builtin_mma_disassemble_acc(vec_C, &acc_0);
@@ -2224,8 +2241,8 @@ class tinyBLAS_BF16_PPC {
packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
for (int x = 0; x<4; x++) {
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
}
}
__builtin_mma_disassemble_acc(vec_C, &acc_0);
@@ -3418,16 +3435,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
return tb.matmul(m, n);
}
#elif defined(__MMA__)
if ((k % 8))
return false;
if(Btype == GGML_TYPE_BF16) {
tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
(const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc,
params->ith, params->nth};
tb.matmul(m, n);
return true;
if (k % 8) {
return false;
}
if (Btype == GGML_TYPE_BF16) {
tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
(const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc,
params->ith, params->nth };
tb.matmul(m, n);
return true;
}
#elif defined(__riscv_zvfbfwma)
#if LMUL == 1
@@ -3516,6 +3536,21 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
#endif
return tb.matmul(m, n);
}
#elif defined(__MMA__)
if (k % 8) {
return false;
}
if (Btype == GGML_TYPE_F16) {
tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> tb{ k,
(const ggml_fp16_t *)A, lda,
(const ggml_fp16_t *)B, ldb,
(float *)C, ldc,
params->ith, params->nth };
tb.matmul(m, n);
return true;
}
#endif
return false;
}

View File

@@ -629,8 +629,8 @@ static __global__ void flash_attn_mask_to_KV_max(
template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
const int nbatch_fa) {
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
const int ne11, const int ne12, const int nbatch_fa) {
constexpr int ncols = ncols1*ncols2;
const int bidx0 = blockIdx.x;
@@ -641,11 +641,14 @@ static __global__ void flash_attn_stream_k_fixup(
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
const bool did_not_have_any_data = kbc0 == kbc0_stop;
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -654,15 +657,19 @@ static __global__ void flash_attn_stream_k_fixup(
return;
}
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
if (jt*ncols1 + j >= ne01) {
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
return;
}
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
// Load the partial result that needs a fixup:
float dst_val = 0.0f;
@@ -681,7 +688,7 @@ static __global__ void flash_attn_stream_k_fixup(
int bidx = bidx0 - 1;
int kbc_stop = kbc0;
while(true) {
const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
@@ -882,8 +889,10 @@ void launch_fattn(
}
}
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int gqa_ratio = Q->ne[2] / K->ne[2];
const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
@@ -958,7 +967,7 @@ void launch_fattn(
blocks_num.x = ntiles_x;
blocks_num.y = parallel_blocks;
blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -1012,7 +1021,7 @@ void launch_fattn(
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
}
} else if (parallel_blocks > 1) {
const dim3 block_dim_combine(DV, 1, 1);

View File

@@ -933,6 +933,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float logit_softcap,
const uint3 ne01,
const int ne02,
const int gqa_ratio,
const int ne11,
const int stride_Q1,
const int stride_Q2,
@@ -940,6 +941,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int stride_V,
const int stride_mask,
const int jt,
const int zt_gqa,
const int kb0_start,
const int kb0_stop) {
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
@@ -1022,7 +1024,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int j = jc / ncols2;
const int c = jc % ncols2;
if (jt*ncols1 + j < int(ne01.z)) {
if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -1408,7 +1410,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int j_dst = jc_dst / ncols2;
const int c_dst = jc_dst % ncols2;
if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
continue;
}
@@ -1447,7 +1449,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
#else
GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
scale, slope, logit_softcap, ne01, ne02,
scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
jt, kb0_start, kb0_stop);
NO_DEVICE_CODE;
@@ -1520,12 +1522,13 @@ static __global__ void flash_attn_ext_f16(
const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
// kbc == k block continuous, current index in continuous ijk space.
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1536,22 +1539,24 @@ static __global__ void flash_attn_ext_f16(
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
while (kbc < kbc_stop && kb0_stop == iter_k) {
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
const int head0 = zt * ncols2;
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
(const half *) (mask + nb33*(sequence % ne33));
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
if (KV_max) {
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@@ -1561,12 +1566,12 @@ static __global__ void flash_attn_ext_f16(
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
} else {
constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
}
kbc += iter_k;
@@ -1580,22 +1585,24 @@ static __global__ void flash_attn_ext_f16(
return;
}
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
const int head0 = zt * ncols2;
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
(const half *) (mask + nb33*(sequence % ne33));
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
if (KV_max) {
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@@ -1605,7 +1612,7 @@ static __global__ void flash_attn_ext_f16(
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
@@ -1739,3 +1746,5 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);

View File

@@ -18,9 +18,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
}
}
if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
return;
if constexpr (ncols2 <= 16) {
if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
return;
}
}
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
@@ -33,6 +35,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
template <int DKQ, int DV>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
@@ -60,17 +63,38 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
if (use_gqa_opt && gqa_ratio % 8 == 0) {
// On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute:
if (cc == GGML_CUDA_CC_VOLTA) {
if (use_gqa_opt && gqa_ratio % 8 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 4 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 2 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio > 4) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 4 == 0) {
if (use_gqa_opt && gqa_ratio > 2) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 2 == 0) {
if (use_gqa_opt && gqa_ratio > 1) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
@@ -79,6 +103,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
}
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
@@ -121,8 +146,30 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
GGML_ASSERT(gqa_ratio % 4 == 0);
if (gqa_ratio % 16 == 0) {
if (gqa_ratio == 20) { // GLM 4.7 Flash
if (cc >= GGML_CUDA_CC_BLACKWELL) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
break;
}
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
if (Q->ne[1] <= 4) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
break;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
break;
}
if (cc >= GGML_CUDA_CC_TURING) {
if (Q->ne[1] <= 4) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst);
break;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
break;
}
// Volta:
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
} else if (gqa_ratio % 16 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
@@ -234,7 +281,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
// The effective batch size for the kernel can be increased by gqa_ratio.
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
for (const ggml_tensor * t : {Q, K, V, mask}) {
if (t == nullptr || ggml_is_quantized(t->type)) {
continue;
@@ -268,7 +315,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_NONE;
}
if (!V_is_K_view) {

View File

@@ -4876,6 +4876,16 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
// Set CUDA_SCALE_LAUNCH_QUEUES before any CUDA API call to improve multi-GPU pipeline parallelism performance
// PR: https://github.com/ggml-org/llama.cpp/pull/19042
if (getenv("CUDA_SCALE_LAUNCH_QUEUES") == nullptr) {
#ifdef _WIN32
_putenv_s("CUDA_SCALE_LAUNCH_QUEUES", "4x");
#else
setenv("CUDA_SCALE_LAUNCH_QUEUES", "4x", 0); // don't overwrite if already set
#endif // _WIN32
}
ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);

View File

@@ -71,7 +71,7 @@ for type_k in TYPES_KV:
f.write(SOURCE_FATTN_VEC.format(type_k=type_k, type_v=type_v))
for ncols in [8, 16, 32, 64]:
for ncols2 in [1, 2, 4, 8, 16]:
for ncols2 in [1, 2, 4, 8, 16, 32]:
if ncols2 > ncols:
continue
ncols1 = ncols // ncols2
@@ -83,9 +83,9 @@ for ncols in [8, 16, 32, 64]:
continue
if head_size_kq == 72:
continue
if head_size_kq != 576 and ncols2 == 16:
if head_size_kq != 576 and ncols2 in (16, 32):
continue
if head_size_kq == 576 and ncols2 not in (4, 16):
if head_size_kq == 576 and ncols2 not in (4, 16, 32):
continue
head_size_v = head_size_kq if head_size_kq != 576 else 512
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))

View File

@@ -85,7 +85,8 @@ set(GGML_OPENCL_KERNELS
mul_mv_q4_0_f32_8x_flat
mul_mv_q4_0_f32_1d_8x_flat
mul_mv_q4_0_f32_1d_16x_flat
mul_mv_q6_k
mul_mv_q6_k_f32
mul_mv_q6_k_f32_flat
mul_mv_q8_0_f32
mul_mv_q8_0_f32_flat
mul_mv_mxfp4_f32

View File

@@ -533,8 +533,10 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
cl_kernel kernel_convert_block_q4_0_noshuffle;
cl_kernel kernel_restore_block_q4_0_noshuffle;
cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
cl_kernel kernel_mul_mv_q6_K_f32;
cl_kernel kernel_mul_mv_q6_K_f32_flat;
cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat;
cl_kernel kernel_solve_tri_f32;
@@ -892,6 +894,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err));
GGML_LOG_CONT(".");
}
@@ -1114,14 +1118,14 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mv_q6_k
// mul_mv_q6_k_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mv_q6_k.cl.h"
#include "mul_mv_q6_k_f32.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mv_q6_k.cl");
const std::string kernel_src = read_file("mul_mv_q6_k_f32.cl");
#endif
backend_ctx->program_mul_mv_q6_K =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
@@ -1130,6 +1134,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mv_q6_k_f32_flat
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mv_q6_k_f32_flat.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mv_q6_k_f32_flat.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q6_K_f32_flat", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mv_q8_0_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2919,6 +2940,50 @@ struct ggml_tensor_extra_cl_q8_0 {
}
};
struct ggml_tensor_extra_cl_q6_K {
// Lower 4 bits of quantized weights.
cl_mem ql = nullptr;
// Upper 2 bits of quantized weights.
cl_mem qh = nullptr;
// Scales for each block.
cl_mem s = nullptr;
// Scales for each super block.
cl_mem d = nullptr;
size_t size_ql = 0;
size_t size_qh = 0;
size_t size_s = 0;
size_t size_d = 0;
~ggml_tensor_extra_cl_q6_K() {
reset();
}
void reset() {
if (ql != nullptr) {
CL_CHECK(clReleaseMemObject(ql));
ql = nullptr;
}
if (qh != nullptr) {
CL_CHECK(clReleaseMemObject(qh));
qh = nullptr;
}
if (s != nullptr) {
CL_CHECK(clReleaseMemObject(s));
s = nullptr;
}
if (d != nullptr) {
CL_CHECK(clReleaseMemObject(d));
d = nullptr;
}
size_ql = 0;
size_qh = 0;
size_s = 0;
size_d = 0;
}
};
//------------------------------------------------------------------------------
// Backend API
//------------------------------------------------------------------------------
@@ -3465,6 +3530,12 @@ struct ggml_backend_opencl_buffer_context {
for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) {
delete e;
}
for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K) {
delete e;
}
for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) {
delete e;
}
}
ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() {
@@ -3527,6 +3598,21 @@ struct ggml_backend_opencl_buffer_context {
return extra;
}
ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() {
ggml_tensor_extra_cl_q6_K * extra;
if (temp_tensor_extras_q6_K.empty()) {
extra = new ggml_tensor_extra_cl_q6_K();
} else {
extra = temp_tensor_extras_q6_K.back();
temp_tensor_extras_q6_K.pop_back();
}
temp_tensor_extras_q6_K_in_use.push_back(extra);
extra->reset();
return extra;
}
void reset() {
for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {
temp_tensor_extras.push_back(e);
@@ -3547,6 +3633,11 @@ struct ggml_backend_opencl_buffer_context {
temp_tensor_extras_q8_0.push_back(e);
}
temp_tensor_extras_q8_0_in_use.clear();
for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) {
temp_tensor_extras_q6_K.push_back(e);
}
temp_tensor_extras_q6_K_in_use.clear();
}
// Pools for extras. Available extras are in `temp_tensor_extras`. Extras
@@ -3562,6 +3653,8 @@ struct ggml_backend_opencl_buffer_context {
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use;
std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0;
std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0_in_use;
std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K;
std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K_in_use;
// The buffer_context is initially created by ggml_backend_buft_alloc_buffer
// before any tensor is initialized (at the beginning of alloc_tensor_range).
@@ -4068,6 +4161,92 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
return;
}
if (tensor->type == GGML_TYPE_Q6_K) {
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
// Allocate the new extra and create aliases from the original.
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
ggml_tensor_extra_cl_q6_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q6_K();
size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4;
size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16;
size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) &&
"Incorrect tensor size");
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
CL_CHECK(clEnqueueWriteBuffer(
queue, data_device, CL_TRUE, 0,
ggml_nbytes(tensor), data, 0, NULL, NULL));
cl_buffer_region region;
// Subbuffer for ql
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
region.size = size_ql;
extra->ql = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
auto previous_origin = region.origin;
// Subbuffer for qh
region.origin = align_to(previous_origin + size_ql, backend_ctx->alignment);
region.size = size_qh;
extra->qh = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
previous_origin = region.origin;
// Subbuffer for scales
region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment);
region.size = size_s;
extra->s = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
previous_origin = region.origin;
// Create subbuffer for d.
region.origin = align_to(previous_origin + size_s, backend_ctx->alignment);
region.size = size_d;
extra->d = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
previous_origin = region.origin;
// Flatten the weights
cl_kernel kernel = backend_ctx->kernel_convert_block_q6_K;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {64, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clReleaseMemObject(data_device));
extra->size_ql = size_ql;
extra->size_qh = size_qh;
extra->size_s = size_s;
extra->size_d = size_d;
tensor->extra = extra;
return;
}
#endif // GGML_OPENCL_SOA_Q
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
@@ -4277,6 +4456,34 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {1, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clEnqueueReadBuffer(
queue, data_device, CL_TRUE, offset,
size, data, 0, NULL, NULL));
CL_CHECK(clReleaseMemObject(data_device));
return;
}
if (tensor->type == GGML_TYPE_Q6_K) {
ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra;
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {1, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
global_work_size, local_work_size, 0, NULL, &evt));
@@ -7765,6 +7972,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra;
#endif
const int ne00 = src0 ? src0->ne[0] : 0;
@@ -8648,14 +8856,49 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
#ifdef GGML_OPENCL_SOA_Q
kernel = backend_ctx->kernel_mul_mv_q6_K_f32_flat;
if (backend_ctx->gpu_family == INTEL) {
nth0 = 16;
nth1 = 2;
ndst = 4;
} else if (backend_ctx->gpu_family == ADRENO) {
nth0 = 64;
nth1 = 2;
ndst = 4;
} else {
GGML_ASSERT(false && "TODO: Unknown GPU");
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r3));
#else
kernel = backend_ctx->kernel_mul_mv_q6_K_f32;
if (backend_ctx->gpu_family == INTEL) {
nth0 = 2;
nth1 = 16;
nth0 = 16;
nth1 = 2;
ndst = 1;
} else if (backend_ctx->gpu_family == ADRENO) {
nth0 = 2;
nth1 = 64;
nth0 = 64;
nth1 = 2;
ndst = 1;
} else {
GGML_ASSERT(false && "TODO: Unknown GPU");
}
@@ -8675,6 +8918,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
#endif // GGML_OPENCL_SOA_Q
break;
case GGML_TYPE_MXFP4: {
#ifdef GGML_OPENCL_SOA_Q
@@ -8777,7 +9021,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
} else if (src0t == GGML_TYPE_Q5_K) {
GGML_ASSERT(false && "not implemented");
} else if (src0t == GGML_TYPE_Q6_K) {
size_t global_work_size[] = {(size_t)(ne01+1)/2*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);

View File

@@ -46,6 +46,16 @@ struct block_q4_0
uint8_t qs[QK4_0 / 2];
};
//------------------------------------------------------------------------------
// block_q6_K
//------------------------------------------------------------------------------
struct block_q6_K {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
half d; // super-block scale
};
//------------------------------------------------------------------------------
// kernel_convert_block_q4_0
// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
@@ -263,3 +273,63 @@ kernel void kernel_restore_block_q8_0(
b->qs[i] = q[i];
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q6_K
// Convert the block_q6_K format to 3 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
// Each thread processes a super block.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q6_K(
global struct block_q6_K * src0,
global uchar * dst_ql,
global uchar * dst_qh,
global char * dst_s,
global half * dst_d
) {
global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0);
global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0);
global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0);
global char * s = (global char *) dst_s + QK_K/16*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
*d = b->d;
for (int i = 0; i < QK_K/2; ++i) {
ql[i] = b->ql[i];
}
for (int i = 0; i < QK_K/4; ++i) {
qh[i] = b->qh[i];
}
for (int i = 0; i < QK_K/16; ++i) {
s[i] = b->scales[i];
}
}
// Restore block_q6_K from flattened arrays.
// Each thread processes a super block.
kernel void kernel_restore_block_q6_K(
global uchar * dst_ql,
global uchar * dst_qh,
global char * dst_s,
global half * dst_d,
global struct block_q6_K * dst
) {
global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0);
global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0);
global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0);
global char * s = (global char *) dst_s + QK_K/16*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
b->d = *d;
for (int i = 0; i < QK_K/2; ++i) {
b->ql[i] = ql[i];
}
for (int i = 0; i < QK_K/4; ++i) {
b->qh[i] = qh[i];
}
for (int i = 0; i < QK_K/16; ++i) {
b->scales[i] = s[i];
}
}

View File

@@ -0,0 +1,194 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
//------------------------------------------------------------------------------
// kernel_mul_mv_q6_K_f32_flat
//------------------------------------------------------------------------------
#define Q6_K_MASK1 0x03
#define Q6_K_MASK2 0x0C
#define Q6_K_MASK3 0x30
#define Q6_K_MASK4 0xC0
#define QK_K 256
inline float block_q_6_K_dot_y_flat(
global uchar * blk_ql,
global uchar * blk_qh,
global char * blk_scales,
global half * blk_d,
global float * yy,
int ib,
int ip,
int is,
int l0
) {
int y_offset = 128*ip + l0;
int q_offset_l = 64*ip + l0;
int q_offset_h = 32*ip + l0;
global uchar * q1 = blk_ql + ib*128 + q_offset_l;
global uchar * q2 = q1 + QK_K/8;
global uchar * qh = blk_qh + ib*64 + q_offset_h;
global char * sc = blk_scales + ib*16 + is;
global float * y = yy + ib * QK_K + y_offset;
float dall = blk_d[ib];
float sumf = 0;
float4 sums = {0.f, 0.f, 0.f, 0.f};
sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & Q6_K_MASK1) << 4)) - 32.f);
sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & Q6_K_MASK2) << 2)) - 32.f);
sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & Q6_K_MASK3) << 0)) - 32.f);
sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & Q6_K_MASK4) >> 2)) - 32.f);
sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & Q6_K_MASK1) << 4)) - 32.f);
sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & Q6_K_MASK2) << 2)) - 32.f);
sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & Q6_K_MASK3) << 0)) - 32.f);
sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & Q6_K_MASK4) >> 2)) - 32.f);
sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & Q6_K_MASK1) << 4)) - 32.f);
sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & Q6_K_MASK2) << 2)) - 32.f);
sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & Q6_K_MASK3) << 0)) - 32.f);
sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & Q6_K_MASK4) >> 2)) - 32.f);
sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & Q6_K_MASK1) << 4)) - 32.f);
sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & Q6_K_MASK2) << 2)) - 32.f);
sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & Q6_K_MASK3) << 0)) - 32.f);
sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & Q6_K_MASK4) >> 2)) - 32.f);
sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]);
return sumf;
}
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4
#define N_SIMDGROUP 2
#define N_SIMDWIDTH 16
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 2
#define N_SIMDWIDTH 64
#endif
#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q6_K_f32_flat(
global uchar * src0_ql,
global uchar * src0_qh,
global char * src0_s,
global half * src0_d,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
int nb = ne00/QK_K;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int i12 = im%ne12;
int i13 = im/ne12;
int first_row = (N_SIMDGROUP * r0 + get_sub_group_id()) * N_DST;
ulong offset_src0 = first_row*nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
ulong offset_src0_ql = offset_src0 * 128;
ulong offset_src0_qh = offset_src0 * 64;
ulong offset_src0_s = offset_src0 * 16;
ulong offset_src0_d = offset_src0;
global uchar * blk_ql = (global uchar *) src0_ql + offset_src0_ql;
global uchar * blk_qh = (global uchar *) src0_qh + offset_src0_qh;
global char * blk_scales = (global char *) src0_s + offset_src0_s;
global half * blk_d = (global half *) src0_d + offset_src0_d;
global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1;
int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0
int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1
int ip = tid/8; // first or second half of (super) block (0 or 1)
int il = tid%8; // each half has 8 parts, one per scale
int n = 4; // 4 scales at a time (and 4 sums)
int l0 = n*il; // offset into half-block, 0..28
int is = 8*ip + l0/16; // 0, 1, 8, 9
float4 sumf = 0;
for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
if (first_row + 0 < ne01) {
sumf.s0 += block_q_6_K_dot_y_flat(blk_ql + 0*nb*128, blk_qh + 0*nb*64, blk_scales + 0*nb*16, blk_d + 0*nb, yy, ib, ip, is, l0);
}
if (first_row + 1 < ne01) {
sumf.s1 += block_q_6_K_dot_y_flat(blk_ql + 1*nb*128, blk_qh + 1*nb*64, blk_scales + 1*nb*16, blk_d + 1*nb, yy, ib, ip, is, l0);
}
if (first_row + 2 < ne01) {
sumf.s2 += block_q_6_K_dot_y_flat(blk_ql + 2*nb*128, blk_qh + 2*nb*64, blk_scales + 2*nb*16, blk_d + 2*nb, yy, ib, ip, is, l0);
}
if (first_row + 3 < ne01) {
sumf.s3 += block_q_6_K_dot_y_flat(blk_ql + 3*nb*128, blk_qh + 3*nb*64, blk_scales + 3*nb*16, blk_d + 3*nb, yy, ib, ip, is, l0);
}
}
float4 tot = (float4)(
sub_group_reduce_add(sumf.s0),
sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2),
sub_group_reduce_add(sumf.s3)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
}
if (first_row + 1 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
}
if (first_row + 2 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
}
if (first_row + 3 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
}
}
}

View File

@@ -2173,13 +2173,6 @@ llm_graph_cb llama_context::graph_get_cb() const {
ggml_set_name(cur, name);
}
if (!cparams.offload_kqv) {
if (strcmp(name, "kqv_merged_cont") == 0) {
// all nodes between the KV store and the attention output are run on the CPU
ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
}
}
// norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
// FIXME: fix in ggml_backend_sched
const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer;

View File

@@ -1630,6 +1630,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
if (!cparams.offload_kqv) {
// all nodes between the KV store and the attention output are run on the CPU
ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
}
ggml_flash_attn_ext_add_sinks(cur, sinks);
ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);

View File

@@ -1737,6 +1737,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
switch (hparams.n_layer) {
case 27: type = LLM_TYPE_16B; break;
case 47: type = LLM_TYPE_30B_A3B; break;
case 60: type = LLM_TYPE_236B; break;
case 61: type = LLM_TYPE_671B; break;
default: type = LLM_TYPE_UNKNOWN;

View File

@@ -8216,8 +8216,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int nh : { 4, }) {
for (int nr3 : { 1, 3, }) {
if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
for (int nr2 : { 1, 4, 16 }) {
if (nr2 == 16 && hsk != 128) continue;
for (int nr2 : { 1, 4, 12 }) {
if (nr2 == 12 && hsk != 128) continue;
//for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
for (int kv : { 113, 512, 1024, }) {
if (nr2 != 1 && kv != 512) continue;