Compare commits

...

2 Commits

Author SHA1 Message Date
Ruben Ortlam
7ded1269ab unify matmul_id shader selection 2026-03-12 14:55:12 +01:00
Ruben Ortlam
664dfc7730 vulkan: unify matmul shader selection 2026-03-12 14:47:18 +01:00

View File

@@ -4939,7 +4939,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
     } else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
         device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device);
     } else {
-        device->shader_core_count = 0;
+        // Set reasonable default when actual count is not known
+        device->shader_core_count = 16;
     }
     device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
@@ -6942,7 +6943,7 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m,
     }
     uint32_t split_k = 1;
-    if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
+    if (m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
         // If k is 'large' and the SMs will fill less than halfway, use split_k.
         uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
         uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
@@ -6979,40 +6980,34 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m,
 static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
-    if (ctx->device->coopmat2) {
-        const uint32_t shader_core_count = ctx->device->shader_core_count;
-        const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
-        const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]);
-        // Use large shader when the N dimension is greater than the medium shader's tile size
-        uint32_t crossover_large = mmp->m->wg_denoms[1];
-        // Prefer large over medium if either:
-        // - medium or large tiles would overfill the GPU
-        // - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not
-        //   (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead)
-        bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count ||
-                            // split_k==3 with large tiles likely better than medium tiles with no split_k.
-                            (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
-        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
-            return aligned ? mmp->a_l : mmp->l;
-        }
-        // Use medium shader when the N dimension is greater than the small shader's tile size
-        uint32_t crossover_medium = mmp->s->wg_denoms[1];
-        if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
-            return aligned ? mmp->a_m : mmp->m;
-        }
-        return aligned ? mmp->a_s : mmp->s;
-    }
-    if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
-        return aligned ? mmp->a_s : mmp->s;
-    }
-    if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
-        return aligned ? mmp->a_m : mmp->m;
-    }
-    return aligned ? mmp->a_l : mmp->l;
+    const bool mm_l = ctx->device->mul_mat_l[src0_type];
+    const bool mm_m = ctx->device->mul_mat_m[src0_type];
+    const bool mm_s = ctx->device->mul_mat_s[src0_type];
+    const uint32_t shader_core_count = ctx->device->shader_core_count;
+    const uint32_t tiles_l = mm_l ? CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]) : 0;
+    const uint32_t tiles_m = mm_m ? CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]) : 0;
+    // Use large shader when the N dimension is greater than the medium shader's tile size
+    uint32_t crossover_large = mm_m ? mmp->m->wg_denoms[1] : 0;
+    // Prefer large over medium if either:
+    // - medium or large tiles would overfill the GPU
+    // - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not
+    //   (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead)
+    bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count ||
+                        // split_k==3 with large tiles likely better than medium tiles with no split_k.
+                        (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
+    if ((mm_l && (n > crossover_large && prefer_large)) || (!mm_m && !mm_s)) {
+        return aligned ? mmp->a_l : mmp->l;
+    }
+    // Use medium shader when the N dimension is greater than the small shader's tile size
+    uint32_t crossover_medium = mmp->s->wg_denoms[1];
+    if ((mm_m && (n > crossover_medium)) || !mm_s) {
+        return aligned ? mmp->a_m : mmp->m;
+    }
+    return aligned ? mmp->a_s : mmp->s;
 
     GGML_UNUSED(src1_type);
 }
@@ -7074,27 +7069,21 @@ static void ggml_vk_matmul(
 static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
     VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
-    if (ctx->device->coopmat2) {
-        // Use large shader when the N dimension is greater than the medium shader's tile size
-        uint32_t crossover_large = mmp->m->wg_denoms[1];
-        if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
-            return aligned ? mmp->a_l : mmp->l;
-        }
-        // Use medium shader when the N dimension is greater than the small shader's tile size
-        uint32_t crossover_medium = mmp->s->wg_denoms[1];
-        if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
-            return aligned ? mmp->a_m : mmp->m;
-        }
-        return aligned ? mmp->a_s : mmp->s;
-    }
-    if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {
-        return aligned ? mmp->a_s : mmp->s;
-    }
-    if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {
-        return aligned ? mmp->a_m : mmp->m;
-    }
-    return aligned ? mmp->a_l : mmp->l;
+    const bool mm_l = ctx->device->mul_mat_id_l[src0_type];
+    const bool mm_m = ctx->device->mul_mat_id_m[src0_type];
+    const bool mm_s = ctx->device->mul_mat_id_s[src0_type];
+    // Use large shader when the N dimension is greater than the medium shader's tile size
+    const uint32_t crossover_large = mm_m ? mmp->m->wg_denoms[1] : (mm_s ? mmp->s->wg_denoms[1] : 0);
+    if ((mm_l && (n > crossover_large)) || (!mm_m && !mm_s)) {
+        return aligned ? mmp->a_l : mmp->l;
+    }
+    // Use medium shader when the N dimension is greater than the small shader's tile size
+    const uint32_t crossover_medium = mm_s ? mmp->s->wg_denoms[1] : 0;
+    if ((mm_m && (n > crossover_medium)) || !mm_s) {
+        return aligned ? mmp->a_m : mmp->m;
+    }
+    return aligned ? mmp->a_s : mmp->s;
 }
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
@@ -8899,7 +8888,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1;
     // Use a placeholder core count if one isn't available. split_k is a big help for perf.
-    const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16;
+    const uint32_t shader_core_count = ctx->device->shader_core_count * shader_core_count_multiplier;
     const uint32_t Br = fa_pipeline_state.Br;
     const uint32_t Bc = fa_pipeline_state.Bc;
@@ -9057,7 +9046,7 @@ static vk_conv_shapes ggml_vk_conv_select_shape(ggml_backend_vk_context * ctx, u
     // We can't query number of shader cores on Intel, use 32 as a placeholder
     // so small convolutions will still choose a smaller tile.
-    const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;
+    const uint32_t shader_core_count = ctx->device->shader_core_count;
     if (K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) {
         return CONV_SHAPE_128x128;