Compare commits

...

2 Commits
b9012 ... b9014

Author SHA1 Message Date
Chen Yuan
d4b0c22f9e ggml-webgpu: add layer norm ops (#22406)
* shader(norm): add layer norm ops

* shader(norm): stabilize floating point computation with Kahan summation and handle mixed types

* shader(norm): remove the non-contiguous strides

* shader(norm): use the original implementation rather than the Kahan summation
2026-05-03 20:52:53 -07:00
Aldehir Rojas
e48034dfc9 common : determine generation prompt using longest common prefix (#22657) 2026-05-04 00:18:23 +02:00
4 changed files with 134 additions and 42 deletions
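
The shader history above mentions trying Kahan (compensated) summation to stabilize the floating point row sums before the final commit reverted to the original accumulation. For reference, a minimal sketch of the generic technique, not the shader code that landed:

#include <cstddef>

// Kahan (compensated) summation: carries the low-order bits lost by each
// addition in a separate compensation term. Shown only to illustrate the
// approach named in the commit history; the shader ultimately kept plain
// accumulation.
float kahan_sum(const float * x, size_t n) {
    float sum = 0.0f;
    float c   = 0.0f;                // compensation for lost low-order bits
    for (size_t i = 0; i < n; i++) {
        float y = x[i] - c;          // apply the compensation to the next term
        float t = sum + y;           // low-order bits of y may be lost here
        c   = (t - sum) - y;         // recover what was lost
        sum = t;
    }
    return sum;
}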

View File

@@ -2116,22 +2116,38 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
return std::nullopt;
}
static std::string common_chat_templates_generation_prompt(const common_chat_template & tmpl, const autoparser::generation_params & inputs) {
autoparser::generation_params params = inputs;
params.add_generation_prompt = false;
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
params.add_generation_prompt = true;
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
size_t prefix_len = 0;
size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size());
while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) {
prefix_len++;
}
return gen_prompt.substr(prefix_len);
}
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs) {
autoparser::generation_params params;
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
const auto & tmpl =
params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
params.tool_choice = inputs.tool_choice;
params.reasoning_format = inputs.reasoning_format;
params.enable_thinking = inputs.enable_thinking;
params.grammar = inputs.grammar;
params.now = inputs.now;
params.add_bos = tmpls->add_bos;
params.add_eos = tmpls->add_eos;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
params.tool_choice = inputs.tool_choice;
params.reasoning_format = inputs.reasoning_format;
params.enable_thinking = inputs.enable_thinking;
params.grammar = inputs.grammar;
params.now = inputs.now;
params.add_generation_prompt = inputs.add_generation_prompt;
params.add_bos = tmpls->add_bos;
params.add_eos = tmpls->add_eos;
if (src.find("<|channel|>") == std::string::npos) {
// map developer to system for all models except for GPT-OSS
@@ -2153,14 +2169,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
workaround::func_args_not_string(params.messages);
}
params.add_generation_prompt = false;
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
params.add_generation_prompt = true;
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
params.generation_prompt = diff.right + diff.suffix;
params.add_generation_prompt = inputs.add_generation_prompt;
params.generation_prompt = common_chat_templates_generation_prompt(tmpl, params);
params.extra_context = common_chat_extra_context();
for (auto el : inputs.chat_template_kwargs) {
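
The new common_chat_templates_generation_prompt helper above renders the conversation twice, once with add_generation_prompt disabled and once with it enabled, and keeps only the suffix that follows their longest common prefix. A self-contained sketch of that comparison, using made-up rendered strings rather than output from a real template:

#include <algorithm>
#include <iostream>
#include <string>

int main() {
    // Hypothetical renderings; in the real code both strings come from
    // common_chat_template_direct_apply_impl with add_generation_prompt
    // toggled off and on.
    std::string no_gen_prompt = "<s>[INST] hello [/INST]";
    std::string gen_prompt    = "<s>[INST] hello [/INST]<assistant>\n";

    // Longest common prefix, same loop as in the helper above.
    size_t prefix_len = 0;
    size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size());
    while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) {
        prefix_len++;
    }

    // Whatever only the second rendering contains is the generation prompt.
    std::cout << gen_prompt.substr(prefix_len);   // prints "<assistant>\n"
}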

View File

@@ -228,11 +228,13 @@ struct ggml_webgpu_get_rows_pipeline_key_hash {
/** Row Norm **/
struct ggml_webgpu_row_norm_pipeline_key {
ggml_op op;
bool inplace;
ggml_op op;
ggml_type src_type;
ggml_type dst_type;
bool inplace;
bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
return op == other.op && inplace == other.inplace;
return op == other.op && src_type == other.src_type && dst_type == other.dst_type && inplace == other.inplace;
}
};
@@ -240,6 +242,8 @@ struct ggml_webgpu_row_norm_pipeline_key_hash {
size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.src_type);
ggml_webgpu_hash_combine(seed, key.dst_type);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
@@ -1097,6 +1101,8 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_row_norm_pipeline_key key = {};
key.op = context.dst->op;
key.src_type = context.src0->type;
key.dst_type = context.dst->type;
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
auto it = row_norm_pipelines.find(key);
@@ -1111,6 +1117,10 @@ class ggml_webgpu_shader_lib {
defines.push_back("RMS_NORM");
variant = "rms_norm";
break;
case GGML_OP_NORM:
defines.push_back("NORM");
variant = "norm";
break;
case GGML_OP_L2_NORM:
defines.push_back("L2_NORM");
variant = "l2_norm";
@@ -1124,6 +1134,22 @@ class ggml_webgpu_shader_lib {
variant += "_inplace";
}
if (key.src_type == GGML_TYPE_F32) {
defines.push_back("SRC_F32");
variant += "_src_f32";
} else if (key.src_type == GGML_TYPE_F16) {
defines.push_back("SRC_F16");
variant += "_src_f16";
}
if (key.dst_type == GGML_TYPE_F32) {
defines.push_back("DST_F32");
variant += "_dst_f32";
} else if (key.dst_type == GGML_TYPE_F16) {
defines.push_back("DST_F16");
variant += "_dst_f16";
}
const uint32_t row_norm_wg_size = 128u;
uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

View File

@@ -2927,6 +2927,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
} else {
return ggml_webgpu_row_norm(ctx, src0, node);
}
case GGML_OP_NORM:
case GGML_OP_L2_NORM:
return ggml_webgpu_row_norm(ctx, src0, node);
case GGML_OP_ROPE:
@@ -4071,6 +4072,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
}
case GGML_OP_RMS_NORM:
case GGML_OP_NORM:
case GGML_OP_L2_NORM:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;

View File

@@ -1,20 +1,17 @@
#ifdef INPLACE
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
src[dst_offset] = scale * src[src_offset];
}
#if defined(SRC_F16) || defined(DST_F16)
enable f16;
#endif
@group(0) @binding(1)
var<uniform> params: Params;
#ifdef SRC_F16
#define SRC_TYPE f16
#else
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
dst[dst_offset] = scale * src[src_offset];
}
#define SRC_TYPE f32
#endif
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#ifdef DST_F16
#define DST_TYPE f16
#else
#define DST_TYPE f32
#endif
struct Params {
@@ -40,9 +37,20 @@ struct Params {
};
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
var<storage, read_write> src: array<SRC_TYPE>;
var<workgroup> scratch: array<f32, WG_SIZE>;
#ifdef INPLACE
@group(0) @binding(1)
var<uniform> params: Params;
#else
@group(0) @binding(1)
var<storage, read_write> dst: array<DST_TYPE>;
@group(0) @binding(2)
var<uniform> params: Params;
#endif
var<workgroup> scratch: array<f32, WG_SIZE * 2u>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@@ -65,34 +73,81 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
if (col >= params.ne0) {
break;
}
sum += pow(src[i_src_row + col], 2.0);
let v = f32(src[i_src_row + col]);
#ifdef NORM
sum += v;
#else
sum += v * v;
#endif
col += WG_SIZE;
}
scratch[lid.x] = sum;
workgroupBarrier();
var offset: u32 = WG_SIZE / 2;
var offset: u32 = WG_SIZE / 2u;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset];
}
offset = offset / 2;
offset /= 2u;
workgroupBarrier();
}
sum = scratch[0];
#ifdef RMS_NORM
#ifdef NORM
let mean = sum / f32(params.ne0);
var sq_sum = 0.0f;
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
let v = f32(src[i_src_row + col]);
let d = v - mean;
sq_sum += d * d;
col += WG_SIZE;
}
workgroupBarrier();
scratch[lid.x] = sq_sum;
workgroupBarrier();
offset = WG_SIZE / 2u;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset];
}
offset /= 2u;
workgroupBarrier();
}
let variance = scratch[0] / f32(params.ne0);
let scale = 1.0 / sqrt(variance + params.eps);
#elif defined(RMS_NORM)
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
#elif defined(L2_NORM)
let scale = 1.0/max(sqrt(sum), params.eps);
#endif
#ifdef NORM
let mean_val = mean;
#else
let mean_val = 0.0f;
#endif
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
update(i_src_row + col, i_dst_row + col, scale);
let i_src = i_src_row + col;
let i_dst = i_dst_row + col;
let v = src[i_src];
#ifdef INPLACE
src[i_dst] = scale * (v - mean_val);
#else
dst[i_dst] = scale * (v - mean_val);
#endif
col += WG_SIZE;
}
}
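
In the NORM branch the shader reduces the row sum to get the mean, reduces the squared deviations to get the variance, and then writes scale * (v - mean) with scale = 1 / sqrt(variance + eps). A scalar C++ reference of the same per-row math, as a cross-check sketch rather than backend code:

#include <cmath>
#include <cstddef>

// Reference layer norm for a single row of ne0 elements, mirroring the
// NORM path of the shader: mean, variance of the centered values, then
// normalization.
void layer_norm_row(const float * src, float * dst, size_t ne0, float eps) {
    float sum = 0.0f;
    for (size_t i = 0; i < ne0; i++) {
        sum += src[i];
    }
    const float mean = sum / (float) ne0;

    float sq_sum = 0.0f;
    for (size_t i = 0; i < ne0; i++) {
        const float d = src[i] - mean;
        sq_sum += d * d;
    }
    const float variance = sq_sum / (float) ne0;
    const float scale    = 1.0f / std::sqrt(variance + eps);

    for (size_t i = 0; i < ne0; i++) {
        dst[i] = scale * (src[i] - mean);
    }
}

For the other variants the shader keeps mean_val at zero and only changes the scale: RMS_NORM uses 1 / sqrt(sum(v*v) / ne0 + eps) and L2_NORM uses 1 / max(sqrt(sum(v*v)), eps).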