Compare commits

...

3 Commits

Author SHA1 Message Date
Ruben Ortlam
715ed28683 use scalar sums 2026-03-07 22:11:40 +01:00
Ruben Ortlam
a9435151db apply scales inline 2026-03-07 14:56:25 +01:00
Ruben Ortlam
d1f8bbd085 vulkan: add int8 coopmat quantized matmul shader 2026-03-07 14:43:21 +01:00
4 changed files with 517 additions and 4 deletions

View File

@@ -3201,6 +3201,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
const uint32_t itm_l = device->coopmat_int_support ? device->coopmat_int_m : 4;
const uint32_t itm_m = device->coopmat_int_support ? device->coopmat_int_m : 4;
const uint32_t itm_s = device->coopmat_int_support ? device->coopmat_int_m : 2;
const uint32_t itn_l = device->coopmat_int_support ? device->coopmat_int_n : 4;
const uint32_t itn_m = device->coopmat_int_support ? device->coopmat_int_n : 2;
const uint32_t itn_s = device->coopmat_int_support ? device->coopmat_int_n : 1;
const uint32_t itk_l = device->coopmat_int_support ? device->coopmat_int_k : 1;
const uint32_t itk_m = device->coopmat_int_support ? device->coopmat_int_k : 1;
const uint32_t itk_s = device->coopmat_int_support ? device->coopmat_int_k : 1;
const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
@@ -3212,9 +3222,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
// Integer MMQ has a smaller shared memory profile, but heavier register use
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, subgroup_size_8 };
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, itm_l, itn_l, itk_l, subgroup_size_8 };
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, itm_m, itn_m, itk_m, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, itm_s, itn_s, itk_s, subgroup_size_8 };
// K-quants use even more registers, mitigate by setting WMITER to 1
l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
@@ -3520,6 +3530,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->coopmat_acc_f16_support) { \
@@ -3529,6 +3547,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
} \
#define CREATE_MMQ2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MMQ(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MMQ(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
@@ -3561,6 +3583,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
if (device->coopmat_int_support) {
CREATE_MMQ2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
}
} else {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -3583,6 +3609,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
if (device->coopmat_int_support) {
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
}
}
GGML_ASSERT(device->subgroup_ballot);
@@ -3616,6 +3646,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
#undef CREATE_MMQ2
#undef CREATE_MMQ
#undef CREATE_MM2
#undef CREATE_MM
} else
@@ -7316,7 +7348,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0;
bool quantize_y = (ctx->device->integer_dot_product || ctx->device->coopmat_int_support) &&
src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0;
// Check for mmq first
vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;

View File

@@ -0,0 +1,390 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_cooperative_matrix : require
#extension GL_KHR_memory_scope_semantics : enable
#if defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif
#ifdef MUL_MAT_ID
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif
#include "types.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 2) writeonly buffer D4 {D_TYPE_VEC4 data_dv4[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
#endif
layout (push_constant) uniform parameter
{
uint M;
uint N;
uint K;
uint stride_a;
uint stride_b;
uint stride_d;
uint batch_stride_a;
uint batch_stride_b;
uint batch_stride_d;
#ifdef MUL_MAT_ID
uint nei0;
uint nei1;
uint nbi1;
uint ne11;
#else
uint base_work_group_z;
uint num_batches;
uint k_split;
uint ne02;
uint ne12;
uint broadcast2;
uint broadcast3;
#endif
} p;
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
// layout (constant_id = 3) const uint BK = 32;
layout (constant_id = 4) const uint WM = 32;
layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
layout (constant_id = 7) const uint TM = 16;
layout (constant_id = 8) const uint TN = 16;
layout (constant_id = 9) const uint TK = 16;
layout (constant_id = 10) const uint WARP = 32;
#define BK 32
const uint shmem_stride = (BK / 4) + 4;
// Shared memory cache
shared uint32_t buf_a_qs[BM * shmem_stride];
shared float16_t buf_a_d[BM];
shared uint32_t buf_b_qs[BN * shmem_stride];
shared float16_t buf_b_d[BN];
#define LOAD_VEC_A (4 * QUANT_R)
#define LOAD_VEC_B 16
#define NUM_WARPS (BLOCK_SIZE / WARP)
shared ivec4 coopmat_stage[TM * TN * NUM_WARPS / 4];
#include "mul_mm_id_funcs.glsl"
#include "mul_mmq_cm1_funcs.glsl"
// Cooperative-matrix (KHR coopmat1) int8 quantized matmul: D = A(quantized) * B(q8_1).
// Each workgroup computes a BM x BN tile of D; gl_WorkGroupID.x encodes both the
// M-tile index and the k-split index, .y the N-tile, .z the batch (or expert id).
void main() {
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_WorkGroupID.z;
// Early-out: this workgroup's N-tile lies entirely past this expert's row count.
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
#endif
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
#ifndef MUL_MAT_ID
// Decompose the flat batch index into (i13, i12) and map through the broadcast
// factors to find the source batch of A (ggml-style dim-2/3 broadcasting).
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
const uint i03 = i13 / p.broadcast3;
const uint i02 = i12 / p.broadcast2;
const uint batch_idx_a = i03 * p.ne02 + i02;
#endif
// ir: M-tile index, ik: k-split index (both packed into gl_WorkGroupID.x).
const uint blocks_m = (p.M + BM - 1) / BM;
const uint ir = gl_WorkGroupID.x % blocks_m;
const uint ik = gl_WorkGroupID.x / blocks_m;
// Per-warp tiling: each subgroup owns a WM x WN sub-tile, iterated as
// WMITER x WNITER steps of TM x TN coopmat fragments.
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER;
const uint warp_i = gl_SubgroupID;
const uint tiw = gl_SubgroupInvocationID;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
// NOTE(review): storestride/store_r/store_c appear unused below — candidates for removal.
const uint storestride = WARP / TM;
const uint store_r = tiw % TM;
const uint store_c = tiw / TM;
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
// Shared-memory load coordinates: each thread loads LOAD_VEC_{A,B}-sized chunks,
// striding loadstride_{a,b} rows per iteration.
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS
// Subgroup-accelerated row-id gather; power-of-two nei0 enables the fast path.
if (bitCount(p.nei0) == 1) {
load_row_ids(expert_idx, true, ic);
} else {
load_row_ids(expert_idx, false, ic);
}
#else
// Scalar fallback: scan the id matrix for rows routed to this expert and
// record the (ii0, ii1) pairs that fall inside this workgroup's BN window.
_ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
if (_ne1 >= ic * BN) {
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
}
_ne1++;
}
}
}
barrier();
#endif
// Workgroup has no work
if (ic * BN >= _ne1) return;
#endif
#ifdef MUL_MAT_ID
const uint start_k = 0;
const uint end_k = p.K;
#else
// k-split: each ik workgroup handles a [start_k, end_k) slice of the K dim.
const uint start_k = ik * p.k_split;
const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif
// Positions are in units of BK-wide quant blocks, hence the / BK scaling.
uint pos_a_ib =
#ifdef MUL_MAT_ID
expert_idx * (p.batch_stride_a / BK) +
#else
batch_idx_a * (p.batch_stride_a / BK) +
#endif
(ir * BM * p.stride_a + start_k) / BK;
#ifdef MUL_MAT_ID
// B rows are gathered indirectly via row_ids inside the loop, so start at 0.
uint pos_b_ib = 0;
#else
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif
coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result[cms_per_row * cms_per_col];
// Float accumulators live per-thread as vec4 chunks; the int32 coopmat results
// are scaled (by the per-block quant scales) and folded in every BK step.
const uint accs_per_thread = (WM * WN) / WARP / 4;
ACC_TYPE_VEC4 sums[accs_per_thread];
[[unroll]] for (uint i = 0; i < accs_per_thread; i++) {
sums[i] = ACC_TYPE_VEC4(0.0f);
}
const uint chunks_per_thread_per_tile = (TM * TN) / (WARP * 4);
for (uint block = start_k; block < end_k; block += BK) {
// Stage one BK-wide slice of A into shared memory (quants + scales).
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
const uint buf_ib = loadc_a + l;
const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
const uint iqs = loadr_a;
block_a_to_shmem(buf_ib, ib, iqs);
}
// Stage one BK-wide slice of B (q8_1) into shared memory.
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
const uint buf_ib = loadc_b + l;
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[buf_ib];
const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
#else
const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
#endif
const uint iqs = loadr_b;
block_b_to_shmem(buf_ib, ib, iqs);
}
barrier();
pos_a_ib += 1;
pos_b_ib += 1;
// Integer accumulators are reset every BK step: scales are applied per block,
// so int32 partial sums must not carry across blocks.
[[unroll]] for (uint idx = 0; idx < cms_per_row * cms_per_col; idx++) {
cm_result[idx] = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
}
// Calculate quants
[[unroll]] for (uint i = 0; i < BK; i += TK) {
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
// offsets/strides are in uint (4 int8) units, hence i / 4 and shmem_stride.
coopMatLoad(cache_a, buf_a_qs, (warp_r * WM + cm_row * TM) * shmem_stride + i / 4, shmem_stride, gl_CooperativeMatrixLayoutRowMajor);
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatLoad(cache_b, buf_b_qs, (warp_c * WN + cm_col * TN) * shmem_stride + i / 4, shmem_stride, gl_CooperativeMatrixLayoutColumnMajor);
cm_result[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, cm_result[cm_col * cms_per_row + cm_row]);
}
}
}
// Store to shmem
const uint subgroup_vec_stride = (TM * TN) / 4;
const uint subgroup_offset = warp_i * subgroup_vec_stride;
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint tile_idx = cm_col * cms_per_row + cm_row;
// Round-trip each int32 tile through shared memory (as ivec4, column-major)
// so each thread can read arbitrary elements to apply the scales.
coopMatStore(cm_result[tile_idx], coopmat_stage, subgroup_offset, TM / 4, gl_CooperativeMatrixLayoutColumnMajor);
// Subgroup-scope barrier: the staging buffer slice is private to this subgroup.
controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
// Each thread grabs chunks and applies the scales
[[unroll]] for (uint chunk = 0; chunk < chunks_per_thread_per_tile; chunk++) {
const uint local_chunk = chunk * WARP + tiw;
const uint col_local = local_chunk / (TM / 4);
const uint row_group = local_chunk % (TM / 4);
const uint row0_local = row_group * 4;
const ivec4 qs = coopmat_stage[subgroup_offset + col_local * (TM / 4) + row_group];
const uint a_row0 = warp_r * WM + cm_row * TM + row0_local;
const uint b_col = warp_c * WN + cm_col * TN + col_local;
// Scale int result by the A-row scales (per element) and the B-column scale.
const ACC_TYPE_VEC4 da = ACC_TYPE_VEC4(buf_a_d[a_row0], buf_a_d[a_row0+1], buf_a_d[a_row0+2], buf_a_d[a_row0+3]);
const ACC_TYPE db = ACC_TYPE(buf_b_d[b_col]);
sums[tile_idx * chunks_per_thread_per_tile + chunk] += ACC_TYPE_VEC4(qs) * da * db;
}
}
}
barrier();
}
// Write-back: dr/dc are this warp's top-left output coordinates.
const uint dr = ir * BM + warp_r * WM;
const uint dc = ic * BN + warp_c * WN;
const bool is_aligned = p.stride_d % 4 == 0;
#ifdef MUL_MAT_ID
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint tile_idx = cm_col * cms_per_row + cm_row;
[[unroll]] for (uint chunk = 0; chunk < chunks_per_thread_per_tile; chunk++) {
const uint local_chunk = chunk * WARP + tiw;
const uint col_local = local_chunk / (TM / 4);
const uint row_group = local_chunk % (TM / 4);
const uint row0_local = row_group * 4;
const uint row_i = dc + cm_col * TN + col_local;
if (row_i >= _ne1) break;
const uint row0_g = dr + cm_row * TM + row0_local;
// Scatter through row_ids: output row/batch come from the expert routing table.
const u16vec2 row_idx = row_ids[row_i - ic * BN];
const uint store_offset = row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + row0_g;
const uint acc_idx = tile_idx * chunks_per_thread_per_tile + chunk;
// Fast path: vec4 store when fully in-bounds and 16-byte aligned;
// otherwise fall back to scalar stores for the 1..4 valid rows.
if (row0_g + 3 < p.M && is_aligned && (store_offset % 4) == 0) {
data_dv4[store_offset / 4] = D_TYPE_VEC4(sums[acc_idx]);
} else if (row0_g + 3 < p.M) {
const ACC_TYPE_VEC4 vals = sums[acc_idx];
data_d[store_offset ] = D_TYPE(vals.x);
data_d[store_offset + 1] = D_TYPE(vals.y);
data_d[store_offset + 2] = D_TYPE(vals.z);
data_d[store_offset + 3] = D_TYPE(vals.w);
} else if (row0_g + 2 < p.M) {
const ACC_TYPE_VEC4 vals = sums[acc_idx];
data_d[store_offset ] = D_TYPE(vals.x);
data_d[store_offset + 1] = D_TYPE(vals.y);
data_d[store_offset + 2] = D_TYPE(vals.z);
} else if (row0_g + 1 < p.M) {
const ACC_TYPE_VEC4 vals = sums[acc_idx];
data_d[store_offset ] = D_TYPE(vals.x);
data_d[store_offset + 1] = D_TYPE(vals.y);
} else if (row0_g < p.M) {
const ACC_TYPE_VEC4 vals = sums[acc_idx];
data_d[store_offset] = D_TYPE(vals.x);
}
}
}
}
#else
// Each k-split writes to its own slice of D, reduced later by the backend.
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint tile_idx = cm_col * cms_per_row + cm_row;
[[unroll]] for (uint chunk = 0; chunk < chunks_per_thread_per_tile; chunk++) {
const uint local_chunk = chunk * WARP + tiw;
const uint col_local = local_chunk / (TM / 4);
const uint row_group = local_chunk % (TM / 4);
const uint row0_local = row_group * 4;
const uint col_g = dc + cm_col * TN + col_local;
if (col_g >= p.N) break;
const uint row0_g = dr + cm_row * TM + row0_local;
const uint store_offset = offsets + col_g * p.stride_d + row0_g;
const uint acc_idx = tile_idx * chunks_per_thread_per_tile + chunk;
// Fast path: aligned vec4 store; scalar fallback near the M boundary.
if (row0_g + 3 < p.M && is_aligned && (store_offset % 4) == 0) {
data_dv4[store_offset / 4] = D_TYPE_VEC4(sums[acc_idx]);
} else if (row0_g + 3 < p.M) {
const ACC_TYPE_VEC4 vals = sums[acc_idx];
data_d[store_offset ] = D_TYPE(vals.x);
data_d[store_offset + 1] = D_TYPE(vals.y);
data_d[store_offset + 2] = D_TYPE(vals.z);
data_d[store_offset + 3] = D_TYPE(vals.w);
} else if (row0_g + 2 < p.M) {
const ACC_TYPE_VEC4 vals = sums[acc_idx];
data_d[store_offset ] = D_TYPE(vals.x);
data_d[store_offset + 1] = D_TYPE(vals.y);
data_d[store_offset + 2] = D_TYPE(vals.z);
} else if (row0_g + 1 < p.M) {
const ACC_TYPE_VEC4 vals = sums[acc_idx];
data_d[store_offset ] = D_TYPE(vals.x);
data_d[store_offset + 1] = D_TYPE(vals.y);
} else if (row0_g < p.M) {
const ACC_TYPE_VEC4 vals = sums[acc_idx];
data_d[store_offset] = D_TYPE(vals.x);
}
}
}
}
#endif // MUL_MAT_ID
}

View File

@@ -0,0 +1,85 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#include "types.glsl"
// Each iqs value maps to a 32-bit integer
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
// 2-byte loads for Q4_0 blocks (18 bytes)
// 4-byte loads for Q4_1 blocks (20 bytes)
// Load one 32-bit chunk of A quants for row buf_ib into shared memory,
// expanding the packed 4-bit nibbles into signed int8 bytes, and record the
// per-block scale on the iqs == 0 lane.
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q4_0
const uint32_t vui = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
#else // DATA_A_Q4_1
const uint32_t vui = data_a_packed32[ib].qs[iqs];
#endif
// Split each byte into its low and high nibble (8 quant values per uint32).
uint32_t lo4 = vui & 0x0F0F0F0F;
uint32_t hi4 = (vui >> 4) & 0x0F0F0F0F;
// subtract 8 from each byte
// (SWAR trick: borrow-free per-byte subtraction via the 0x80 bias bits)
lo4 = ((lo4 | 0x80808080) - 0x08080808) ^ 0x80808080;
hi4 = ((hi4 | 0x80808080) - 0x08080808) ^ 0x80808080;
// Low nibbles are the first 16 values of the block, high nibbles the last 16.
buf_a_qs[buf_ib * shmem_stride + iqs ] = lo4;
buf_a_qs[buf_ib * shmem_stride + iqs + 4] = hi4;
if (iqs == 0) {
#ifdef DATA_A_Q4_0
buf_a_d[buf_ib] = FLOAT_TYPE(data_a_packed16[ib].d);
#else // DATA_A_Q4_1
// NOTE(review): this branch writes no scale, leaving buf_a_d[buf_ib]
// uninitialized for Q4_1 (which also carries a min/offset term that the
// subtract-8 above does not model). Only q4_0 appears to be wired up by the
// shader generator, so this path looks unfinished — confirm before enabling.
#endif
}
}
#endif
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
// 2-byte loads for Q5_0 blocks (22 bytes)
// 4-byte loads for Q5_1 blocks (24 bytes)
}
#endif
#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
#endif
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
#endif
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
#if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes)
#endif
#if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes)
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
#endif
#if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes)
#endif
// Stage one 16-byte chunk of q8_1 B data (block_q8_1_x4_packed128 layout, four
// q8_1 blocks per super-block) into shared memory for row buf_ib, and record
// the per-block f16 scale (ds.x) on the first chunk of the row.
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint outer = ib >> 2; // super-block index (4 q8_1 blocks each)
    const uint inner = ib & 3;  // q8_1 block within the super-block
    // One lane per row is enough to stash the scale.
    // NOTE(review): assumed ds.x is the plain scale here — confirm whether a
    // / TK pre-division is expected by the scale application in the matmul.
    if (iqs == 0) {
        buf_b_d[buf_ib] = data_b[outer].ds[inner].x;
    }
    const ivec4 q = data_b[outer].qs[inner * 2 + iqs];
    const uint base = buf_ib * shmem_stride + iqs * 4;
    buf_b_qs[base    ] = q.x;
    buf_b_qs[base + 1] = q.y;
    buf_b_qs[base + 2] = q.z;
    buf_b_qs[base + 3] = q.w;
}

View File

@@ -447,6 +447,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2";
base_dict["ACC_TYPE_VEC4"] = f16acc ? "f16vec4" : "vec4";
if (f16acc) {
base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
}
@@ -591,6 +592,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
if (coopmat && tname == "q4_0") {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq_cm1.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"}, {"D_TYPE_VEC4", "vec4"}}), fp16, coopmat, coopmat2, f16acc);
}
}
}