use row_split when Br >= 4, change reductions to use shared memory if row_split == 1

This commit is contained in:
Ruben Ortlam
2026-02-05 12:51:59 +01:00
parent 343d285b98
commit affe132f53
2 changed files with 48 additions and 38 deletions

View File

@@ -3173,7 +3173,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
break;
default:
wg_size = scalar_flash_attention_workgroup_size;
wg_size = device->subgroup_size * 4;
break;
}
@@ -3201,15 +3201,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, device->subgroup_size); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, device->subgroup_size); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, device->subgroup_size); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, device->subgroup_size); \
} \
} \
} \
@@ -15933,7 +15933,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph *
ggml_vk_print_graph_origin(tensor, done);
}
if (avg_err > 0.5 || std::isnan(avg_err)) {
if (avg_err > 0.01 || std::isnan(avg_err)) {
std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
if (src0 != nullptr) {

View File

@@ -19,10 +19,11 @@
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t row_split = 4;
const uint32_t row_split = (Br < 4) ? 1 : 4;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
const uint32_t num_subgroups = WorkGroupSize / SubGroupSize;
layout (binding = 0) readonly buffer Q {float data_q[];};
@@ -42,8 +43,10 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in ACC_
return elem;
}
shared float tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];
const uint32_t tmpsh_reduction_size = row_split == 1 ? num_subgroups * D_split : 0;
const uint32_t tmpsh_size = tmpsh_reduction_size > 4 ? tmpsh_reduction_size : 4;
shared float tmpsh[tmpsh_size];
shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size];
shared FLOAT_TYPE masksh[Bc][Br];
@@ -269,63 +272,70 @@ void main() {
barrier();
}
// prevent race on tmpsh
barrier();
// reduce across threads
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float rowmaxf, eMf;
float rowmaxf = Mf[r];
tmpsh[tid] = Mf[r];
// Compute max across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
}
if (row_split == 1) {
// Reduce inside workgroup with shmem
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
}
barrier();
rowmaxf = max(max(max(tmpsh[0 * D_split + d_tid],
tmpsh[1 * D_split + d_tid]),
tmpsh[2 * D_split + d_tid]),
tmpsh[3 * D_split + d_tid]);
}
rowmaxf = tmpsh[d_tid];
barrier();
float Moldf = Mf[r];
// M = max(rowmax, Mold)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
eMf = exp(Moldf - Mf[r]);
float eMf = exp(Moldf - Mf[r]);
Lf[r] = eMf*Lf[r];
tmpsh[tid] = Lf[r];
// Compute sum across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
Lf[r] += subgroupShuffleXor(Lf[r], s);
}
if (row_split == 1) {
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
}
barrier();
Lf[r] = tmpsh[0 * D_split + d_tid] +
tmpsh[1 * D_split + d_tid] +
tmpsh[2 * D_split + d_tid] +
tmpsh[3 * D_split + d_tid];
}
Lf[r] = tmpsh[d_tid];
barrier();
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = ACC_TYPE(eMf) * Of[r][d];
tmpshv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
Of[r][d] += ACC_TYPEV4(tmpshv4[tid + s]);
tmpshv4[tid] = Of[r][d];
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
Of[r][d] += subgroupShuffleXor(Of[r][d], s);
}
if (row_split == 1) {
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
}
barrier();
Of[r][d] = tmpsh_accv4[0 * D_split + d_tid] +
tmpsh_accv4[1 * D_split + d_tid] +
tmpsh_accv4[2 * D_split + d_tid] +
tmpsh_accv4[3 * D_split + d_tid];
}
Of[r][d] = ACC_TYPEV4(tmpshv4[d_tid]);
barrier();
}
}