Compare commits

...

12 Commits

Author SHA1 Message Date
Georgi Gerganov
292f8e231c model-conversion : cast logits to float32 2025-12-13 22:24:21 +02:00
Jeff Bolz
5266379bca llama_context: synchronize before reallocating output buffer (#17974) 2025-12-13 09:19:51 -06:00
Xuan-Son Nguyen
4d5ae24c0a arg: fix common_params_parse not accepting negated arg (#17991) 2025-12-13 12:53:37 +01:00
Gustavo Rocha Dias
66ba51252e cmake: correct scope - link ws2_32 for MinGW/w64devkit builds in cpp-httplib (#17972)
* fix - w64devkit build

* fix - w64devkit build private scope
2025-12-13 12:46:36 +01:00
Jeff Bolz
36255a2268 vulkan: support get_rows for i32 (#17941) 2025-12-13 10:12:53 +01:00
Jeff Bolz
3229a23fa6 vulkan: support GGML_OP_DIAG (#17893) 2025-12-13 10:07:49 +01:00
Jeff Bolz
303f8615e9 vulkan: Multi-pass softmax for large number of cols (#17892)
When the number of cols is large, split each row across multiple workgroups.
There are three phases that communicate partial results through temp buffers:
(1) compute max partials
(2) take max of partials, compute sum(exp(x-max)) partials
(3) sum partials, compute scaled result
2025-12-13 10:04:29 +01:00
Georgi Gerganov
3c6391e748 speculative-simple : free batch on exit (#17985) 2025-12-13 09:48:34 +02:00
Sigbjørn Skjæret
8e4d678528 common : skip model validation when --completion-bash is requested (#17975) 2025-12-13 08:40:50 +01:00
Jeff Bolz
07a10c1090 vulkan: Allow non-pow2 n_experts in topk_moe (#17872) 2025-12-13 08:40:04 +01:00
Sigbjørn Skjæret
2bc94e7928 add llama-completion to completion-bash executables (#17976) 2025-12-13 08:35:50 +01:00
Daniel Bevenius
fd1085ffb7 model-conversion : use CONVERTED_MODEL value for converted model [no ci] (#17984)
* model-conversion : use CONVERTED_MODEL value for converted model [no ci]

This commit updates the model verification scripts to use the
CONVERTED_MODEL environment variable instead of using the MODEL_PATH
(the original model path) as the basis for the converted model file
name.

The motivation for this that currently if the converted model file name
differs from the original model directory/name the verification scripts
will look for the wrong .bin files that were generating when running the
models.
For example, the following steps were not possible:
```console
(venv) $ huggingface-cli download google/gemma-3-270m-it --local-dir ggml-org/gemma-3-270m
(venv) $ python3 convert_hf_to_gguf.py ggml-org/gemma-3-270m --outfile test-bf16.gguf --outtype bf16
(venv) $ cd examples/model-conversion/
(venv) $ export MODEL_PATH=../../ggml-org/gemma-3-270m
(venv) $ export CONVERTED_MODEL=../../test-bf16.gguf
(venv) $ make causal-verify-logits
...
Data saved to data/llamacpp-test-bf16.bin
Data saved to data/llamacpp-test-bf16.txt
Error: llama.cpp logits file not found: data/llamacpp-gemma-3-270m.bin
Please run scripts/run-converted-model.sh first to generate this file.
make: *** [Makefile:62: causal-verify-logits] Error 1
```

With the changes in this commit, the above steps will now work as
expected.
2025-12-13 08:34:26 +01:00
23 changed files with 481 additions and 41 deletions

View File

@@ -504,7 +504,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage) {
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) {
throw std::invalid_argument("error: --model is required\n");
}
@@ -639,6 +639,7 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
"llama-batched-bench",
"llama-bench",
"llama-cli",
"llama-completion",
"llama-convert-llama2c-to-ggml",
"llama-cvector-generator",
"llama-embedding",
@@ -723,7 +724,7 @@ static void add_rpc_devices(const std::string & servers) {
}
}
bool common_params_parse(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map) {
bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map) {
common_params dummy_params;
common_params_context ctx_arg = common_params_parser_init(dummy_params, ex, nullptr);
@@ -732,6 +733,9 @@ bool common_params_parse(int argc, char ** argv, llama_example ex, std::map<comm
for (const auto & arg : opt.args) {
arg_to_options[arg] = &opt;
}
for (const auto & arg : opt.args_neg) {
arg_to_options[arg] = &opt;
}
}
// TODO @ngxson : find a way to deduplicate this code

View File

@@ -115,7 +115,7 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
// parse input arguments from CLI into a map
// TODO: support repeated args in the future
bool common_params_parse(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map);
bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map);
// initialize argument parser context - used by test-arg-parser and preset
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);

View File

@@ -1,10 +1,13 @@
#!/usr/bin/env python3
import numpy as np
import sys
import os
import numpy as np
from pathlib import Path
# Add utils directory to path for direct script execution
sys.path.insert(0, str(Path(__file__).parent.parent / "utils"))
from common import get_model_name_from_env_path # type: ignore[import-not-found]
def quick_logits_check(pytorch_file, llamacpp_file):
"""Lightweight sanity check before NMSE"""
@@ -35,20 +38,13 @@ def quick_logits_check(pytorch_file, llamacpp_file):
return True
def main():
model_path = os.getenv('MODEL_PATH')
if not model_path:
print("Error: MODEL_PATH environment variable not set")
sys.exit(1)
if not os.path.exists(model_path):
print(f"Error: Model file not found: {model_path}")
sys.exit(1)
model_name = os.path.basename(model_path)
model_name = get_model_name_from_env_path('MODEL_PATH')
data_dir = Path("data")
pytorch_file = data_dir / f"pytorch-{model_name}.bin"
llamacpp_file = data_dir / f"llamacpp-{model_name}.bin"
llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL')
print(f"Using converted model: {llamacpp_model_name}")
llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin"
if not pytorch_file.exists():
print(f"Error: PyTorch logits file not found: {pytorch_file}")

View File

@@ -200,7 +200,7 @@ with torch.no_grad():
logits = outputs.logits
# Extract logits for the last token (next token prediction)
last_logits = logits[0, -1, :].cpu().numpy()
last_logits = logits[0, -1, :].float().cpu().numpy()
print(f"Logits shape: {logits.shape}")
print(f"Last token logits shape: {last_logits.shape}")

View File

@@ -5,6 +5,7 @@ import sys
import os
import argparse
from pathlib import Path
from common import get_model_name_from_env_path # type: ignore[import-not-found]
def calculate_nmse(reference, test):
mse = np.mean((test - reference) ** 2)
@@ -67,11 +68,13 @@ def main():
parser.add_argument('-m', '--model-path', required=True, help='Path to the model directory')
args = parser.parse_args()
model_name = os.path.basename(args.model_path)
model_name = get_model_name_from_env_path('MODEL_PATH')
data_dir = Path("data")
pytorch_file = data_dir / f"pytorch-{model_name}.bin"
llamacpp_file = data_dir / f"llamacpp-{model_name}.bin"
llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL')
llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin"
print(f"Model name: {model_name}")
print(f"PyTorch logits file: {pytorch_file}")

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env python3
import os
import sys
def get_model_name_from_env_path(env_path_name):
model_path = os.getenv(env_path_name)
if not model_path:
print(f"Error: {env_path_name} environment variable not set")
sys.exit(1)
if not os.path.exists(model_path):
print(f"Error: Model file not found: {model_path}")
sys.exit(1)
name = os.path.basename(os.path.normpath(model_path))
if name.endswith(".gguf"):
name = name[:-5]
return name

View File

@@ -255,6 +255,8 @@ int main(int argc, char ** argv) {
LOG_INF("target:\n\n");
common_perf_print(ctx_tgt, smpl);
llama_batch_free(batch_tgt);
common_sampler_free(smpl);
common_speculative_free(spec);

View File

@@ -659,6 +659,7 @@ struct vk_device_struct {
vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_log[2];
vk_pipeline pipeline_tri[2];
vk_pipeline pipeline_diag[2];
vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_roll_f32;
@@ -722,6 +723,11 @@ struct vk_device_struct {
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
vk_pipeline pipeline_soft_max_back_f32;
vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16;
vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16;
vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16;
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -757,7 +763,8 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_split_k_reduce;
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
// [2] is for whether to take n_experts from spec constant (0) or push constant (1)
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
std::vector<vk_pipeline_ref> all_pipelines;
@@ -1149,6 +1156,7 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_topk_moe_push_constants {
uint32_t n_rows;
uint32_t n_experts_push;
uint32_t n_expert_used;
float clamp_min;
float clamp_max;
@@ -3730,6 +3738,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -3917,6 +3926,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
@@ -3996,6 +4008,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32, "soft_max_large1_f32", soft_max_large1_f32_len, soft_max_large1_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32, "soft_max_large2_f32", soft_max_large2_f32_len, soft_max_large2_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32, "soft_max_large3_f32", soft_max_large3_f32_len, soft_max_large3_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -4204,10 +4223,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true, device->subgroup_size);
for (uint32_t use_push = 0; use_push < 2; ++use_push) {
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0, use_push}, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
}
}
for (auto &c : compiles) {
@@ -8274,6 +8295,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
switch (op) {
case GGML_OP_GET_ROWS:
GGML_ASSERT(src1->type == GGML_TYPE_I32);
if (src0->type == GGML_TYPE_I32) {
// i32 src only supports i32 result
GGML_ASSERT(dst->type == GGML_TYPE_I32);
return ctx->device->pipeline_get_rows[src0->type];
}
if (dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_get_rows[src0->type];
}
@@ -8400,6 +8426,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_DIAG:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_CLAMP:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_clamp_f32;
@@ -8554,7 +8586,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
return ctx->device->pipeline_topk_moe[idx][mode];
// use n_experts from push constant if it's not equal to the power of two spec constant
bool use_push = dst->ne[0] != (1u << idx);
return ctx->device->pipeline_topk_moe[idx][mode][use_push];
}
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -9091,6 +9125,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_COS:
case GGML_OP_LOG:
case GGML_OP_TRI:
case GGML_OP_DIAG:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
@@ -9778,6 +9813,12 @@ static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
}
static void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p));
}
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -10111,7 +10152,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, {
vk_op_soft_max_push_constants pc {
ncols,
src1 != nullptr ? nrows_y : (uint32_t)0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -10122,7 +10163,55 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
n_head_log2,
nrows_x,
src2 != nullptr
});
};
if (ncols <= 16384) {
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc));
} else {
vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0);
vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a;
vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a;
vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst);
uint32_t elems_per_wg = 128 * 4;
uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg);
size_t tmp_size = num_wgs * nrows_x * sizeof(float);
if (ctx->prealloc_size_x < tmp_size) {
ctx->prealloc_size_x = tmp_size;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (ctx->prealloc_size_y < tmp_size) {
ctx->prealloc_size_y = tmp_size;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size };
vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size };
std::array<uint32_t, 3> elements = { num_wgs, nrows_x, 1 };
vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32;
vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32;
vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32;
ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
ggml_vk_sync_buffers(ctx, subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
ggml_vk_sync_buffers(ctx, subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
ctx->prealloc_x_need_sync = true;
ctx->prealloc_y_need_sync = true;
}
}
static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -10158,6 +10247,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
vk_op_topk_moe_push_constants pc {};
pc.n_rows = n_rows;
pc.n_experts_push = n_experts;
pc.n_expert_used = n_expert_used;
if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
@@ -11857,6 +11947,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_TRI:
ggml_vk_tri(ctx, compute_ctx, src0, node);
break;
case GGML_OP_DIAG:
ggml_vk_diag(ctx, compute_ctx, src0, node);
break;
case GGML_OP_CLAMP:
ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -12832,8 +12926,7 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
}
const int n_expert = softmax->ne[0];
// n_expert must be a power of 2
if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
if (n_expert > (1 << (num_topk_moe_pipelines-1))) {
return false;
}
@@ -13877,6 +13970,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_I32:
return true;
default:
return false;
@@ -14001,6 +14095,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LOG:
case GGML_OP_TRI:
case GGML_OP_DIAG:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->type == op->src[0]->type;
case GGML_OP_ARGSORT:
@@ -14591,6 +14686,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_TRI) {
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
} else if (tensor->op == GGML_OP_DIAG) {
tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_CLAMP) {
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);

View File

@@ -0,0 +1,29 @@
#version 450
#include "rte.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
if (i10 == i11) {
const float val = float(data_a[get_aoffset() + i13*p.nb03 + i12*p.nb02 + 0*p.nb01 + i10*p.nb00]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
} else {
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
}
}

View File

@@ -26,9 +26,9 @@ void main() {
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
#if defined(DATA_A_BF16)
FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
TEMP_TYPE v = TEMP_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
#else
FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
TEMP_TYPE v = TEMP_TYPE(data_a[a_offset + i00]);
#endif
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[d_offset + i00] = D_TYPE(v);

View File

@@ -0,0 +1,62 @@
#version 450
#include "soft_max_large_common.glsl"
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint rowx = gl_WorkGroupID.y;
const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
const uint32_t i03 = rowx / (p.ne01 * p.ne02);
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
const uint32_t i01 = rowx % p.ne01;
uint rowy_start = 0;
if (p.KY > 0) {
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
}
if (rowx >= p.nrows_x) {
return;
}
float slope = get_slope(rowx);
// Find max
FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
[[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
const uint col = col0 + tid;
FLOAT_TYPE a = FLOAT_TYPE(0);
if (col < p.KX) {
a = data_a[rowx * p.KX + col];
}
FLOAT_TYPE b = FLOAT_TYPE(0);
if (p.KY > 0 && col < p.KX) {
b = data_b[rowy_start + col];
}
FLOAT_TYPE v = a * p.scale + slope * b;
if (col < p.KX) {
max_val = max(max_val, v);
}
}
// reduce across the workgroup
vals[tid] = max_val;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] = max(vals[tid], vals[tid + s]);
}
barrier();
}
if (tid == 0) {
max_val = vals[0];
data_m[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = max_val;
}
}

View File

@@ -0,0 +1,79 @@
#version 450
#include "soft_max_large_common.glsl"
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint rowx = gl_WorkGroupID.y;
const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
const uint32_t i03 = rowx / (p.ne01 * p.ne02);
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
const uint32_t i01 = rowx % p.ne01;
uint rowy_start = 0;
if (p.KY > 0) {
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
}
if (rowx >= p.nrows_x) {
return;
}
float slope = get_slope(rowx);
// Find max
FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
[[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
if (i + tid < gl_NumWorkGroups.x) {
max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
}
}
// reduce across the workgroup
vals[tid] = max_val;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] = max(max_val, vals[tid + s]);
}
barrier();
}
max_val = vals[0];
barrier();
FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
// Compute sum{exp(x - max)}
[[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
const uint col = col0 + tid;
if (col >= p.KX) {
break;
}
// compute exp(a*scale+b*slope), add it to sum
const uint i = rowx * p.KX + col;
FLOAT_TYPE val;
val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
sum += val;
data_d[i] = D_TYPE(val);
}
// reduce across the workgroup
vals[tid] = sum;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] += vals[tid + s];
}
barrier();
}
if (tid == 0) {
sum = vals[0];
data_s[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = sum;
}
}

View File

@@ -0,0 +1,65 @@
#version 450
#include "soft_max_large_common.glsl"
shared FLOAT_TYPE sumsh[BLOCK_SIZE];
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint rowx = gl_WorkGroupID.y;
const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
const uint32_t i03 = rowx / (p.ne01 * p.ne02);
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
const uint32_t i01 = rowx % p.ne01;
uint rowy_start = 0;
if (p.KY > 0) {
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
}
if (rowx >= p.nrows_x) {
return;
}
FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
[[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
if (i + tid < gl_NumWorkGroups.x) {
max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
sum += data_s[rowx * gl_NumWorkGroups.x + i + tid];
}
}
// reduce across the workgroup
vals[tid] = max_val;
sumsh[tid] = sum;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] = max(max_val, vals[tid + s]);
sumsh[tid] += sumsh[tid + s];
}
barrier();
}
max_val = vals[0];
sum = sumsh[0];
if (p.has_sinks != 0) {
sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
}
FLOAT_TYPE rcpdivisor = 1.0/sum;
[[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
const uint col = col0 + tid;
if (col >= p.KX) {
continue;
}
data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
}
}

View File

@@ -0,0 +1,53 @@
#extension GL_EXT_control_flow_attributes : enable
layout (push_constant) uniform parameter
{
uint KX;
uint KY;
uint ne00;
uint ne01;
uint ne02;
uint ne12;
uint ne13;
uint nb11;
uint nb12;
uint nb13;
float scale;
float max_bias;
float m0;
float m1;
uint n_head_log2;
uint nrows_x;
uint has_sinks;
} p;
#include "types.glsl"
layout(constant_id = 0) const uint BLOCK_SIZE = 128;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 1) const uint num_iters = 4;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) readonly buffer Z {float data_c[];};
layout (binding = 3) buffer D {D_TYPE data_d[];};
layout (binding = 4) buffer M {float data_m[];};
layout (binding = 5) buffer S {float data_s[];};
shared FLOAT_TYPE vals[BLOCK_SIZE];
float get_slope(uint rowx) {
float slope = 1.0f;
// ALiBi
if (p.max_bias > 0.0f) {
const uint h = (rowx / p.ne01) % p.ne02; // head index
const float base = h < p.n_head_log2 ? p.m0 : p.m1;
const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
slope = pow(base, exp);
}
return slope;
}

View File

@@ -10,6 +10,7 @@
layout (push_constant) uniform parameter
{
uint n_rows;
uint n_experts_push;
uint n_expert_used;
float clamp_min;
float clamp_max;
@@ -18,11 +19,16 @@ layout (push_constant) uniform parameter
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 1) const uint n_experts_spec = 512;
layout(constant_id = 2) const bool with_norm = true;
layout(constant_id = 3) const bool late_softmax = false;
layout(constant_id = 4) const bool nexperts_use_push = false;
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);
layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
@@ -94,7 +100,7 @@ void main() {
}
if (!late_softmax) {
softmax_warp_inplace(wt, n_experts, lane, false);
softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push);
}
// at this point, each thread holds a portion of softmax,

View File

@@ -704,13 +704,15 @@ void process_shaders() {
shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
if (tname == "f16") {
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
} else {
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
}
string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
}
string_to_spv("get_rows_i32", "get_rows.comp", {{"TEMP_TYPE", "uint"}, {"A_TYPE", "uint"}, {"B_TYPE", "int"}, {"D_TYPE", "uint"}});
string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
@@ -854,6 +856,8 @@ void process_shaders() {
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("diag_f16", "diag.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("diag_f32", "diag.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -899,6 +903,13 @@ void process_shaders() {
string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_large1_f32", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_large2_f32", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_large3_f32", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_large1_f32_f16", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_large2_f32_f16", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_large3_f32_f16", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});

View File

@@ -1,5 +1,5 @@
{
"extraPaths": ["gguf-py"],
"extraPaths": ["gguf-py", "examples/model-conversion/scripts"],
"pythonVersion": "3.9",
"pythonPlatform": "All",
"reportUnusedImport": "warning",

View File

@@ -1318,6 +1318,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
#endif
synchronize();
buf_output = nullptr;
logits = nullptr;
embd = nullptr;

View File

@@ -72,6 +72,10 @@ int main(void) {
argv = {"binary_name", "--draft", "123"};
assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_EMBEDDING));
// negated arg
argv = {"binary_name", "--no-mmap"};
assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
printf("test-arg-parser: test valid usage\n\n");

View File

@@ -7652,6 +7652,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
for (float max_bias : {0.0f, 8.0f}) {
for (float scale : {1.0f, 0.1f}) {
for (int64_t ne0 : {16, 1024}) {
@@ -7971,8 +7974,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (bool with_norm : {false, true}) {
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
}
test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));

View File

@@ -171,7 +171,7 @@ server_presets::server_presets(int argc, char ** argv, common_params & base_para
}
// read base args from router's argv
common_params_parse(argc, argv, LLAMA_EXAMPLE_SERVER, base_args);
common_params_to_map(argc, argv, LLAMA_EXAMPLE_SERVER, base_args);
// remove any router-controlled args from base_args
for (const auto & cargs : control_args) {

View File

@@ -11,8 +11,9 @@ endif()
target_link_libraries (${TARGET} PRIVATE Threads::Threads)
if (WIN32 AND NOT MSVC)
target_link_libraries(${TARGET} PUBLIC ws2_32)
target_link_libraries(${TARGET} PRIVATE ws2_32)
endif()
target_compile_features(${TARGET} PRIVATE cxx_std_17)
target_compile_definitions(${TARGET} PRIVATE