Compare commits

...

1 Commits

Author SHA1 Message Date
Ruben Ortlam
abb9f3c42b vulkan: fix MMQ shader push constants and multi-dispatch (#19732) 2026-02-19 14:59:16 +01:00

View File

@@ -57,6 +57,8 @@ layout (push_constant) uniform parameter
uint nbi1;
uint ne11;
#else
uint base_work_group_z;
uint num_batches;
uint k_split;
uint ne02;
uint ne12;
@@ -108,7 +110,7 @@ void main() {
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
const uint expert_idx = gl_WorkGroupID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
@@ -118,7 +120,7 @@ void main() {
#endif
#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
@@ -276,7 +278,7 @@ void main() {
const uint dc = ic * BN + warp_c * WN;
#ifndef MUL_MAT_ID
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
#endif
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {