Compare commits

...

18 Commits

Author SHA1 Message Date
Georgi Gerganov 5372fc6461 wip 2026-02-10 23:44:42 +02:00
Georgi Gerganov 0cc02542a8 wip 2026-02-10 23:36:46 +02:00
Georgi Gerganov 08358235a3 wip 2026-02-10 22:46:09 +02:00
Georgi Gerganov bd7c16f0a4 metal : extend l2_norm support for non-cont src0 2026-02-10 22:45:59 +02:00
Georgi Gerganov 6bd21ebb29 wip 2026-02-10 22:07:56 +02:00
Georgi Gerganov 862d720ad1 wip 2026-02-10 21:53:34 +02:00
Georgi Gerganov 89dd9f6a10 wip 2026-02-10 21:24:39 +02:00
Georgi Gerganov e480e383fd wip 2026-02-10 20:50:47 +02:00
Georgi Gerganov 1c312dc758 wip 2026-02-10 20:33:38 +02:00
Georgi Gerganov 835a949286 wip 2026-02-10 20:02:06 +02:00
Georgi Gerganov b1264663c2 wip 2026-02-10 19:36:49 +02:00
Georgi Gerganov e2c0463eab tests : simplify 2026-02-10 18:32:28 +02:00
Georgi Gerganov c82bc9c030 cont : s0 is always 1 2026-02-10 18:32:28 +02:00
Georgi Gerganov 029c30fda4 cont : extend bin support 2026-02-10 18:32:28 +02:00
Georgi Gerganov 0b0bfb20f4 tests : extend bin bcast for permuted src1 2026-02-10 18:32:28 +02:00
Georgi Gerganov ff77be289e metal : consolidate unary ops 2026-02-10 18:32:27 +02:00
Georgi Gerganov 404d0c8e80 cont 2026-02-10 11:22:50 +02:00
Georgi Gerganov 25dad910ab models : optimizing qwen3next graph 2026-02-10 10:43:41 +02:00
12 changed files with 643 additions and 900 deletions

View File

@@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
GGML_ASSERT(nb00 == sizeof(src0_t));
const auto [ir0, ir1] = get_thread_range(params, src0);
const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
GGML_ASSERT(ggml_are_same_shape(src0, src1));
}
const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1);
#ifdef GGML_USE_ACCELERATE
vDSP_fn_t vDSP_op = nullptr;
@@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
if (is_src1_contiguous) {
if (is_src1_contiguous_rows) {
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t nr0 = ne00 / ne10;
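This hunk is the CPU-side half of the "extend bin support" work: instead of requiring a fully contiguous src1 (and asserting same shapes otherwise), broadcasting now only needs src1 to have contiguous rows. A minimal sketch of that distinction, assuming float-type tensors and the public ggml.h API; the real ggml_is_contiguous_rows also handles edge cases such as empty and block-quantized tensors:

#include "ggml.h"

// "Contiguous rows": elements within a row (dim 0) are packed back-to-back,
// while the strides of dims 1..3 may be arbitrary, e.g. for a permuted view.
static bool rows_are_packed(const struct ggml_tensor * t) {
    return t->nb[0] == ggml_type_size(t->type);
}

// Full contiguity additionally chains the higher strides:
// nb[1] == nb[0]*ne[0], nb[2] == nb[1]*ne[1], nb[3] == nb[2]*ne[2].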

View File

@@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t * src0,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*int s0, */ const int s1,
/*const int s0,*/
const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s00,
const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s10,
const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
@@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t * src0,
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const uint32_t i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10]);
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
}
dst_row[i0] = (dst_t) result;
@@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*int s0, */ const int s1,
/*const int s0,*/
const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s00,
const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s10,
const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
@@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
const int i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10]);
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
}
dst_row[i0] = (dst_t) result;
@@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
cnb[3] *= cne[3];
};
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
@@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
size_t s0 = nb0 / sizeof(dst_t);
//size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
@@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s00 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128;
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
@@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13);
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
}
} else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13);
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
}
}
}
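The k_bin_bcast kernels above stop assuming unit innermost strides: s00 and s10 are now passed for real, so src0 and src1 may be permuted views whenever the launcher's checks allow it. A standalone sketch of the element-offset arithmetic this enables (illustrative names, not the actual kernel code):

#include <cstdio>

// Flat element offset of (i0, i1, i2, i3) given element strides s0..s3
// (byte strides divided by the element size). Packed rows mean s0 == 1;
// a permuted view can have s0 > 1, which is what s00/s10 now carry into the kernel.
static long long elem_offset(long long i0, long long i1, long long i2, long long i3,
                             long long s0, long long s1, long long s2, long long s3) {
    return i0*s0 + i1*s1 + i2*s2 + i3*s3;
}

int main() {
    // hypothetical transposed 2D f32 view: stride 4 elements along dim 0, 1 along dim 1
    printf("%lld\n", elem_offset(2, 1, 0, 0, /*s0=*/4, /*s1=*/1, /*s2=*/0, /*s3=*/0)); // prints 9
    return 0;
}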

View File

@@ -4542,6 +4542,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
// TODO: should become:
//return ggml_is_contiguous_rows(op->src[0]);
return ggml_is_contiguous(op->src[0]);
default:
return false;

View File

@@ -267,6 +267,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
case GGML_OP_SUM_ROWS:
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
case GGML_OP_GLU:
case GGML_OP_SCALE:

View File

@@ -212,61 +212,69 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
char base[256];
char name[256];
const int64_t n = ggml_nelements(op);
int op_num = -1;
const char * op_str = "undefined";
switch (op->op) {
case GGML_OP_SCALE: op_str = "scale"; break;
case GGML_OP_FILL: op_str = "fill"; break;
case GGML_OP_CLAMP: op_str = "clamp"; break;
case GGML_OP_SQR: op_str = "sqr"; break;
case GGML_OP_SQRT: op_str = "sqrt"; break;
case GGML_OP_SIN: op_str = "sin"; break;
case GGML_OP_COS: op_str = "cos"; break;
case GGML_OP_LOG: op_str = "log"; break;
case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break;
case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break;
case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break;
case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break;
case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break;
case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break;
case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break;
case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break;
case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break;
case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_TANH: op_str = "tanh"; break;
case GGML_UNARY_OP_RELU: op_str = "relu"; break;
case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break;
case GGML_UNARY_OP_GELU: op_str = "gelu"; break;
case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break;
case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break;
case GGML_UNARY_OP_SILU: op_str = "silu"; break;
case GGML_UNARY_OP_ELU: op_str = "elu"; break;
case GGML_UNARY_OP_NEG: op_str = "neg"; break;
case GGML_UNARY_OP_ABS: op_str = "abs"; break;
case GGML_UNARY_OP_SGN: op_str = "sgn"; break;
case GGML_UNARY_OP_STEP: op_str = "step"; break;
case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
case GGML_UNARY_OP_EXP: op_str = "exp"; break;
case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break;
case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break;
case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break;
case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break;
case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break;
case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break;
case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break;
case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break;
case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break;
case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break;
case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break;
case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break;
case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break;
case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
default: GGML_ABORT("fatal error");
} break;
default: GGML_ABORT("fatal error");
};
const char * suffix = "";
if (n % 4 == 0) {
suffix = "_4";
}
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
snprintf(name, 256, "%s", base);
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.c4 = is_c4;
res.cnt = is_cnt;
return res;
}
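With this change there is a single kernel_unary_<src>_<dst> template per type pair, specialized at pipeline-build time through two function constants: the op selector at FC_UNARY + 0 and the small-contiguous fast-path flag at FC_UNARY + 1. A standalone sketch of how the specialized pipeline name is formed, mirroring the snprintf calls above (the "f32" strings stand in for ggml_type_name results, and OP_UNARY_NUM_SQR = 13 comes from the header later in this diff):

#include <cstdio>

int main() {
    char base[256];
    char name[256];
    const int  op_num = 13;    // OP_UNARY_NUM_SQR
    const bool is_c4  = true;  // ne[0] % 4 == 0 -> float4 variant
    const bool is_cnt = true;  // contiguous and fewer than 32768 elements
    snprintf(base, sizeof(base), "kernel_unary_%s_%s%s", "f32", "f32", is_c4 ? "_4" : "");
    snprintf(name, sizeof(name), "%s_op=%d_cnt=%d", base, op_num, is_cnt);
    printf("%s\n", name);      // kernel_unary_f32_f32_4_op=13_cnt=1
    return 0;
}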
@@ -1472,13 +1480,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_met
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_L2_NORM);
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
char base[256];
char name[256];
snprintf(base, 256, "kernel_l2_norm_f32");
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
@@ -1486,6 +1496,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
res.c4 = is_c4;
res.smem = 32*sizeof(float);
return res;

View File

@@ -1011,6 +1011,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
}
switch (op->op) {
case GGML_OP_SCALE:
case GGML_OP_FILL:
case GGML_OP_CLAMP:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_TANH:
@@ -1030,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_EXPM1:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
default:
return false;
}
@@ -1061,8 +1070,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
case GGML_OP_REPEAT:
case GGML_OP_SCALE:
case GGML_OP_FILL:
case GGML_OP_CONV_TRANSPOSE_1D:
return true;
case GGML_OP_CONV_TRANSPOSE_2D:
@@ -1070,14 +1077,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32;
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_TRI:
@@ -1087,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 &&

View File

@@ -80,7 +80,8 @@
#define FC_SSM_CONV 900
#define FC_SOLVE_TRI 1000
#define FC_COUNT_EQUAL 1100
#define FC_BIN 1200
#define FC_UNARY 1200
#define FC_BIN 1300
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
@@ -89,6 +90,35 @@
#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
#define OP_UNARY_NUM_SCALE 10
#define OP_UNARY_NUM_FILL 11
#define OP_UNARY_NUM_CLAMP 12
#define OP_UNARY_NUM_SQR 13
#define OP_UNARY_NUM_SQRT 14
#define OP_UNARY_NUM_SIN 15
#define OP_UNARY_NUM_COS 16
#define OP_UNARY_NUM_LOG 17
#define OP_UNARY_NUM_LEAKY_RELU 18
#define OP_UNARY_NUM_TANH 100
#define OP_UNARY_NUM_RELU 101
#define OP_UNARY_NUM_SIGMOID 102
#define OP_UNARY_NUM_GELU 103
#define OP_UNARY_NUM_GELU_ERF 104
#define OP_UNARY_NUM_GELU_QUICK 105
#define OP_UNARY_NUM_SILU 106
#define OP_UNARY_NUM_ELU 107
#define OP_UNARY_NUM_NEG 108
#define OP_UNARY_NUM_ABS 109
#define OP_UNARY_NUM_SGN 110
#define OP_UNARY_NUM_STEP 111
#define OP_UNARY_NUM_HARDSWISH 112
#define OP_UNARY_NUM_HARDSIGMOID 113
#define OP_UNARY_NUM_EXP 114
#define OP_UNARY_NUM_SOFTPLUS 115
#define OP_UNARY_NUM_EXPM1 116
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -124,6 +154,31 @@ typedef struct {
int32_t dim;
} ggml_metal_kargs_concat;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float slope;
float scale;
float bias;
float val;
float min;
float max;
} ggml_metal_kargs_unary;
typedef struct {
int32_t ne00;
int32_t ne01;
@@ -181,20 +236,6 @@ typedef struct {
uint64_t nb3;
} ggml_metal_kargs_repeat;
typedef struct {
float scale;
float bias;
} ggml_metal_kargs_scale;
typedef struct {
float val;
} ggml_metal_kargs_fill;
typedef struct {
float min;
float max;
} ggml_metal_kargs_clamp;
typedef struct {
int64_t nk0;
int64_t ne00;
@@ -498,8 +539,21 @@ typedef struct {
typedef struct {
int32_t ne00;
int32_t ne00_4;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float eps;
} ggml_metal_kargs_l2_norm;
@@ -881,10 +935,6 @@ typedef struct {
int max_period;
} ggml_metal_kargs_timestep_embedding;
typedef struct {
float slope;
} ggml_metal_kargs_leaky_relu;
typedef struct {
int32_t ne00;
int32_t ne01;

View File

@@ -287,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
n_fuse = ggml_metal_op_acc(ctx, idx);
} break;
case GGML_OP_SCALE:
{
n_fuse = ggml_metal_op_scale(ctx, idx);
} break;
case GGML_OP_FILL:
{
n_fuse = ggml_metal_op_fill(ctx, idx);
} break;
case GGML_OP_CLAMP:
{
n_fuse = ggml_metal_op_clamp(ctx, idx);
} break;
case GGML_OP_LEAKY_RELU:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
@@ -426,10 +418,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_top_k(ctx, idx);
} break;
case GGML_OP_LEAKY_RELU:
{
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
} break;
case GGML_OP_TRI:
{
n_fuse = ggml_metal_op_tri(ctx, idx);
@@ -722,119 +710,6 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float scale;
float bias;
memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
ggml_metal_kargs_scale args = {
/*.scale =*/ scale,
/*.bias =*/ bias,
};
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const float val = ggml_get_op_params_f32(op, 0);
ggml_metal_kargs_fill args = {
/*.val =*/ val
};
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float min;
float max;
memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
ggml_metal_kargs_clamp args = {
/*.min =*/ min,
/*.max =*/ max,
};
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@@ -846,19 +721,79 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
int64_t n = ggml_nelements(op);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
if (n % 4 == 0) {
n /= 4;
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_kargs_unary args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.slope =*/ 0.0,
/*.scale =*/ 0.0,
/*.bias =*/ 0.0,
/*.val =*/ 0.0,
/*.min =*/ 0.0,
/*.max =*/ 0.0,
};
if (op->op == GGML_OP_LEAKY_RELU) {
args.slope = ggml_get_op_params_f32(op, 0);
}
if (op->op == GGML_OP_SCALE) {
args.scale = ggml_get_op_params_f32(op, 0);
args.bias = ggml_get_op_params_f32(op, 1);
}
if (op->op == GGML_OP_FILL) {
args.val = ggml_get_op_params_f32(op, 0);
}
if (op->op == GGML_OP_CLAMP) {
args.min = ggml_get_op_params_f32(op, 0);
args.max = ggml_get_op_params_f32(op, 1);
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
if (pipeline.cnt) {
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
} else {
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const int nth = MIN(args.ne00, nth_max);
const int nk0 = (args.ne00 + nth - 1)/nth;
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
}
return 1;
}
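For tensors that do not take the small-contiguous fast path, the dispatch above covers each row with ceil(ne00/nth) threadgroups and maps dims 1..3 onto the rest of the grid. A small standalone illustration of that arithmetic, mirroring the MIN(...) logic above with an arbitrarily chosen shape:

#include <algorithm>
#include <cstdio>

int main() {
    const int ne00 = 1000, ne01 = 8, ne02 = 4, ne03 = 1; // hypothetical tensor shape
    const int max_threads = 1024;                        // device/pipeline limit (varies)
    const int nth_max = std::min(256, max_threads);      // cap threads per threadgroup
    const int nth     = std::min(ne00, nth_max);         // threads along dim 0
    const int nk0     = (ne00 + nth - 1) / nth;          // threadgroups to cover one row
    // grid: (nk0*ne01, ne02, ne03) threadgroups of (nth, 1, 1) threads
    printf("grid = (%d, %d, %d), threads = %d\n", nk0 * ne01, ne02, ne03, nth); // (32, 4, 1), 256
    return 0;
}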
@@ -3044,39 +2979,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
float eps;
memcpy(&eps, op->op_params, sizeof(float));
int nth = 32; // SIMD width
ggml_metal_kargs_l2_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.eps =*/ eps,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.eps =*/ eps,
};
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00/4);
const size_t smem = pipeline.smem;
const int64_t nrows = ggml_nrows(op->src[0]);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
return 1;
}
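The L2-norm dispatch now assigns one threadgroup per row across (ne01, ne02, ne03) and picks the reduction width by doubling up from one SIMD group. A standalone sketch of that selection with assumed numbers:

#include <algorithm>
#include <cstdio>

int main() {
    const int ne00 = 4096;        // hypothetical row length
    const int max_threads = 1024; // pipeline limit (device dependent)
    int nth = 32;                 // start at one SIMD group
    while (nth < ne00 && nth < max_threads) {
        nth *= 2;
    }
    nth = std::min(nth, max_threads);
    printf("threads per threadgroup: %d\n", nth); // 1024
    return 0;
}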
@@ -4084,42 +4039,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float slope;
memcpy(&slope, op->op_params, sizeof(float));
ggml_metal_kargs_leaky_relu args = {
/*.slope =*/ slope
};
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

View File

@@ -46,9 +46,6 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
@@ -86,7 +83,6 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);

View File

@@ -895,6 +895,192 @@ enum ggml_sort_order {
GGML_SORT_ORDER_DESC,
};
constant float GELU_COEF_A = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
constant float p_erf = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
template<typename T>
T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
return sign_x * y;
}
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
template <typename T0, typename T>
kernel void kernel_unary_impl(
constant ggml_metal_kargs_unary & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_unary_op
#define FC_CNT FC_unary_cnt
device const T0 * src0_ptr;
device T * dst_ptr;
int i0;
if (FC_CNT) {
i0 = tgpig.x;
src0_ptr = (device const T0 *) (src0);
dst_ptr = (device T *) (dst);
} else {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int k0 = tgpig.x/args.ne01;
const int i01 = tgpig.x - k0*args.ne01;
i0 = k0*ntg.x + tpitg.x;
src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
}
{
//threadgroup_barrier(mem_flags::mem_none);
if (!FC_CNT) {
if (i0 >= args.ne0) {
return;
}
}
device const T0 & x = src0_ptr[i0];
if (FC_OP == OP_UNARY_NUM_SCALE) {
dst_ptr[i0] = args.scale * x + args.bias;
}
if (FC_OP == OP_UNARY_NUM_FILL) {
dst_ptr[i0] = args.val;
}
if (FC_OP == OP_UNARY_NUM_CLAMP) {
dst_ptr[i0] = clamp(x, args.min, args.max);
}
if (FC_OP == OP_UNARY_NUM_SQR) {
dst_ptr[i0] = x * x;
}
if (FC_OP == OP_UNARY_NUM_SQRT) {
dst_ptr[i0] = sqrt(x);
}
if (FC_OP == OP_UNARY_NUM_SIN) {
dst_ptr[i0] = sin(x);
}
if (FC_OP == OP_UNARY_NUM_COS) {
dst_ptr[i0] = cos(x);
}
if (FC_OP == OP_UNARY_NUM_LOG) {
dst_ptr[i0] = log(x);
}
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(x * args.slope);
}
if (FC_OP == OP_UNARY_NUM_TANH) {
dst_ptr[i0] = precise::tanh(x);
}
if (FC_OP == OP_UNARY_NUM_RELU) {
dst_ptr[i0] = fmax(0.0f, x);
}
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
dst_ptr[i0] = 1.0f / (1.0f + exp(-x));
}
if (FC_OP == OP_UNARY_NUM_GELU) {
dst_ptr[i0] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
dst_ptr[i0] = 0.5f*x*(1.0f + erf_approx(SQRT_2_INV*x));
}
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
dst_ptr[i0] = x * (1.0f/(1.0f + exp(GELU_QUICK_COEF*x)));
}
if (FC_OP == OP_UNARY_NUM_SILU) {
dst_ptr[i0] = x / (1.0f + exp(-x));
}
if (FC_OP == OP_UNARY_NUM_ELU) {
dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(exp(x) - 1.0f);
}
if (FC_OP == OP_UNARY_NUM_NEG) {
dst_ptr[i0] = -x;
}
if (FC_OP == OP_UNARY_NUM_ABS) {
dst_ptr[i0] = fabs(x);
}
if (FC_OP == OP_UNARY_NUM_SGN) {
dst_ptr[i0] = T(x > 0.0f) - T(x < 0.0f);
}
if (FC_OP == OP_UNARY_NUM_STEP) {
dst_ptr[i0] = T(x > 0.0f);
}
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
dst_ptr[i0] = x * fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f));
}
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
dst_ptr[i0] = fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f));
}
if (FC_OP == OP_UNARY_NUM_EXP) {
dst_ptr[i0] = exp(x);
}
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
dst_ptr[i0] = select(log(1.0f + exp(x)), x, x > 20.0f);
}
if (FC_OP == OP_UNARY_NUM_EXPM1) {
// TODO: precise implementation
dst_ptr[i0] = exp(x) - 1.0f;
}
}
#undef FC_OP
#undef FC_CNT
}
typedef decltype(kernel_unary_impl<float, float>) kernel_unary_t;
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float>;
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4>;
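The GELU_ERF branch relies on the erf_approx helper above (Abramowitz and Stegun 7.1.26 form). A host-side C++ transcription, handy for spot-checking the shader against std::erf; the constants are copied from the shader:

#include <cmath>
#include <cstdio>

static float erf_approx_host(float x) {
    const float p  = 0.3275911f;
    const float a1 = 0.254829592f, a2 = -0.284496736f, a3 = 1.421413741f;
    const float a4 = -1.453152027f, a5 = 1.061405429f;
    const float s = (float) ((x > 0.0f) - (x < 0.0f)); // sign(x)
    x = std::fabs(x);
    const float t = 1.0f / (1.0f + p * x);
    const float y = 1.0f - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * std::exp(-x * x);
    return s * y;
}

int main() {
    const float xs[] = { -2.0f, -0.5f, 0.0f, 0.5f, 2.0f };
    for (float x : xs) {
        printf("x = % .2f  approx = % .6f  std::erf = % .6f\n", x, erf_approx_host(x), std::erf(x));
    }
    return 0;
}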
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
@@ -1114,414 +1300,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
kernel void kernel_scale_f32(
constant ggml_metal_kargs_scale & args,
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * args.scale + args.bias;
}
kernel void kernel_scale_f32_4(
constant ggml_metal_kargs_scale & args,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * args.scale + args.bias;
}
kernel void kernel_fill_f32(
constant ggml_metal_kargs_fill & args,
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
kernel void kernel_fill_f32_4(
constant ggml_metal_kargs_fill & args,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
kernel void kernel_clamp_f32(
constant ggml_metal_kargs_clamp & args,
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = clamp(src0[tpig], args.min, args.max);
}
kernel void kernel_clamp_f32_4(
constant ggml_metal_kargs_clamp & args,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = clamp(src0[tpig], args.min, args.max);
}
kernel void kernel_relu_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = max(0.0f, src0[tpig]);
}
kernel void kernel_relu_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = max(0.0f, src0[tpig]);
}
kernel void kernel_sigmoid_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
}
kernel void kernel_sigmoid_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
}
kernel void kernel_tanh_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = precise::tanh(src0[tpig]);
}
kernel void kernel_tanh_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = precise::tanh(src0[tpig]);
}
constant float GELU_COEF_A = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
kernel void kernel_gelu_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
kernel void kernel_gelu_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
// BEWARE !!!
// Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
// This was observed with Falcon 7B and 40B models
//
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
kernel void kernel_gelu_quick_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
kernel void kernel_gelu_quick_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
constant float p_erf = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
template<typename T>
T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
return sign_x * y;
}
kernel void kernel_gelu_erf_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
}
kernel void kernel_gelu_erf_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
}
kernel void kernel_silu_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = x / (1.0f + exp(-x));
}
kernel void kernel_silu_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = x / (1.0f + exp(-x));
}
kernel void kernel_elu_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
const float x = src0[tpig];
dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
}
kernel void kernel_elu_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
const float4 x = src0[tpig];
dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
}
kernel void kernel_sqr_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src0[tpig];
}
kernel void kernel_sqr_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src0[tpig];
}
kernel void kernel_sqrt_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sqrt(src0[tpig]);
}
kernel void kernel_sqrt_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sqrt(src0[tpig]);
}
kernel void kernel_sin_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sin(src0[tpig]);
}
kernel void kernel_sin_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sin(src0[tpig]);
}
kernel void kernel_cos_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = cos(src0[tpig]);
}
kernel void kernel_cos_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = cos(src0[tpig]);
}
kernel void kernel_log_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = log(src0[tpig]);
}
kernel void kernel_log_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = log(src0[tpig]);
}
kernel void kernel_neg_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = -src0[tpig];
}
kernel void kernel_neg_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = -src0[tpig];
}
kernel void kernel_abs_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = fabs(src0[tpig]);
}
kernel void kernel_abs_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = fabs(src0[tpig]);
}
kernel void kernel_sgn_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sign(src0[tpig]);
}
kernel void kernel_sgn_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sign(src0[tpig]);
}
kernel void kernel_step_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = step(0.0f, src0[tpig]);
}
kernel void kernel_step_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = step(0.0f, src0[tpig]);
}
kernel void kernel_hardswish_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
const float x = src0[tpig];
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}
kernel void kernel_hardswish_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
const float4 x = src0[tpig];
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}
kernel void kernel_hardsigmoid_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
const float x = src0[tpig];
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}
kernel void kernel_hardsigmoid_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
const float4 x = src0[tpig];
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}
kernel void kernel_exp_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]);
}
kernel void kernel_exp_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]);
}
kernel void kernel_softplus_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
}
kernel void kernel_softplus_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
}
kernel void kernel_expm1_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]) - 1.0f;
}
kernel void kernel_expm1_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]) - 1.0f;
}
kernel void kernel_reglu_f32(
constant ggml_metal_kargs_glu & args,
device const char * src0,
@@ -2928,26 +2706,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
kernel void kernel_l2_norm_f32(
template <typename T0, typename T>
kernel void kernel_l2_norm_impl(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort ntg[[threads_per_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
@@ -2965,12 +2749,16 @@ kernel void kernel_l2_norm_f32(
const float scale = 1.0f/sqrt(max(sumf, args.eps));
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
y[i00] = x[i00] * scale;
}
}
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
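The generalized kernel still computes the same row-wise normalization; per row it applies

\[ y_i = \frac{x_i}{\sqrt{\max\left(\sum_j x_j^2,\ \varepsilon\right)}} \]

which is just the formula implied by scale = 1.0f/sqrt(max(sumf, args.eps)) above, with the _4 variant accumulating dot(x, x) four lanes at a time over float4 loads.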
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,
@@ -5072,24 +4860,6 @@ kernel void kernel_argsort_merge_f32_i32(
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
kernel void kernel_leaky_relu_f32(
constant ggml_metal_kargs_leaky_relu & args,
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
const float x = src0[tpig];
dst[tpig] = x > 0.0f ? x : x * args.slope;
}
kernel void kernel_leaky_relu_f32_4(
constant ggml_metal_kargs_leaky_relu & args,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
const float4 x = src0[tpig];
dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
}
constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
@@ -9939,7 +9709,7 @@ kernel void kernel_opt_step_sgd_f32(
template<typename T>
kernel void kernel_memset(
constant ggml_metal_kargs_fill & args,
constant ggml_metal_kargs_memset & args,
device T * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;

View File

@@ -117,49 +117,40 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
const float scale = 1.0f / sqrtf(S_v);
q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);
cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(g, "g_in", il);
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
g = ggml_permute(ctx0, g, 2, 1, 3, 0);
beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
beta = ggml_permute(ctx0, beta, 2, 0, 1, 3);
cb(q, "q_perm", il);
cb(k, "k_perm", il);
cb(v, "v_perm", il);
cb(beta, "beta_perm", il);
cb(g, "g_perm", il);
cb(state, "state_in", il);
GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_v && v->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_v && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
// Do padding
const int64_t chunk_size = CHUNK_SIZE;
@@ -170,7 +161,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
q = ggml_pad(ctx0, q, 0, pad, 0, 0);
k = ggml_pad(ctx0, k, 0, pad, 0, 0);
v = ggml_pad(ctx0, v, 0, pad, 0, 0);
g = ggml_pad(ctx0, g, pad, 0, 0, 0);
g = ggml_pad(ctx0, g, 0, pad, 0, 0);
beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
cb(q, "q_pad", il);
@@ -187,12 +178,12 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs);
k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs);
k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_v * n_seqs);
v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs);
v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_v * n_seqs);
beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
@@ -208,7 +199,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
@@ -216,24 +206,22 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn), attn);
ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
attn = ggml_mul(ctx0, lin_solve, causal_mask);
attn = ggml_add(ctx0, attn, identity);
attn = ggml_add(ctx0, lin_solve, identity);
cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
v = ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)));
ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t);
ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum);
k_beta = ggml_cont(ctx0, ggml_transpose(ctx0, k_beta));
ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
ggml_tensor * k_cumdecay =
ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
ggml_tensor * k_cumdecay = ggml_mul_mat(ctx0, kbeta_gexp, attn);
cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
@@ -241,7 +229,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
// vectorized calculation of key_gdiff
// improved from the chunked version:
// g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
@@ -264,9 +251,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
1, chunk_size, n_chunks, g_diff_exp->ne[3]);
ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp);
ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
@@ -274,14 +260,11 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)
// state to be updated per chunk
ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)
// shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
ggml_tensor * core_attn_out = nullptr;
state = ggml_cont_4d(ctx0, ggml_transpose(ctx0, state), S_v, S_v, 1, H_v * n_seqs);
for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
// shape: (S_k, chunk_size, 1, H_k * n_seqs)
ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul
@@ -300,20 +283,17 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
cb(attn_chunk, "attn_chunk", il);
ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
// v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
ggml_tensor * v_prime = ggml_mul_mat(ctx0, k_cumdecay_chunk, state);
cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)
// v_new = v_i - v_prime
ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
cb(v_new, "v_new_chunk", il);
ggml_tensor * v_new_t = ggml_sub(ctx0, v_chunk, v_prime);
cb(v_new_t, "v_new_chunk_t", il);
// attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk);
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, ggml_transpose(ctx0, gexp_chunk));
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state, q_g_exp);
cb(attn_inter, "attn_inter_chunk", il);
// core_attn_out[:, :, i] = attn_inter + attn @ v_new
@@ -329,30 +309,30 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
// kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
//ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff_t, v_new_t);
// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
new_state = ggml_add(ctx0,
ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
ggml_tensor * gexp_last_chunk = get_slice_2d(ctx0, g_last_exp, chunk);
state = ggml_mul(ctx0, state, gexp_last_chunk);
state = ggml_add(ctx0, state, kgdmulvnew);
}
state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
state = ggml_transpose(ctx0, state);
// truncate padded tokens
ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
S_v, n_tokens, H_v, n_seqs,
ggml_row_size(core_attn_out->type, S_v),
ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
output_tokens = ggml_cont(ctx0, output_tokens);
cb(output_tokens, "output_tokens", il);
// permute back to (S_v, H_v, n_tokens, n_seqs)
output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
output_tokens = ggml_cont(ctx0, output_tokens);
return {output_tokens, new_state};
return {output_tokens, state};
}
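// For reference, a minimal per-head sketch of the recurrence the chunk loop above builds,
// written as plain row-major matrix code following the in-code comments; the matmul() helper
// and the C (chunk size) / dk / dv names are illustrative assumptions, not the ggml graph.
#include <vector>

// out[r x c] = A[r x k] * B[k x c], all row-major
static std::vector<float> matmul(const std::vector<float> & A, const std::vector<float> & B, int r, int k, int c) {
    std::vector<float> out((size_t) r*c, 0.0f);
    for (int i = 0; i < r; ++i)
        for (int j = 0; j < c; ++j)
            for (int t = 0; t < k; ++t)
                out[(size_t) i*c + j] += A[(size_t) i*k + t]*B[(size_t) t*c + j];
    return out;
}

// one chunk of the gated delta rule, per head:
//   S          [dk x dv] recurrent state (updated in place)
//   q_g        [C  x dk] q_chunk * exp(g_chunk)
//   k_cum      [C  x dk] k_cumdecay_chunk
//   attn       [C  x C ] intra-chunk attention (attn_chunk)
//   v          [C  x dv] v_chunk
//   k_gdiff    [C  x dk] key_gdiff of this chunk
//   g_last_exp scalar    exp(g_last) decay of this chunk
// returns      [C  x dv] core_attn_out of this chunk
static std::vector<float> delta_net_chunk(std::vector<float> & S,
        const std::vector<float> & q_g, const std::vector<float> & k_cum,
        const std::vector<float> & attn, const std::vector<float> & v,
        const std::vector<float> & k_gdiff, float g_last_exp, int C, int dk, int dv) {
    std::vector<float> v_new = matmul(k_cum, S, C, dk, dv);               // v_prime = k_cumdecay @ S
    for (size_t i = 0; i < v_new.size(); ++i) v_new[i] = v[i] - v_new[i]; // v_new = v - v_prime
    std::vector<float> out = matmul(q_g, S, C, dk, dv);                   // attn_inter = (q * exp(g)) @ S
    std::vector<float> av  = matmul(attn, v_new, C, C, dv);               // attn @ v_new
    for (size_t i = 0; i < out.size(); ++i) out[i] += av[i];              // out = attn_inter + attn @ v_new
    for (int i = 0; i < dk; ++i)                                          // S = S*exp(g_last) + k_gdiff^T @ v_new
        for (int j = 0; j < dv; ++j) {
            float acc = 0.0f;
            for (int t = 0; t < C; ++t) acc += k_gdiff[(size_t) t*dk + i]*v_new[(size_t) t*dv + j];
            S[(size_t) i*dv + j] = S[(size_t) i*dv + j]*g_last_exp + acc;
        }
    return out;
}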
std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive(
@@ -360,8 +340,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_aut
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * beta,
ggml_tensor * state,
ggml_tensor * b, // beta
ggml_tensor * s, // state
int il) {
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
@@ -375,71 +355,55 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_aut
GGML_ASSERT(v->ne[2] == n_tokens);
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
const float eps_norm = hparams.f_norm_rms_eps;
const float scale = 1.0f / sqrtf(S_k);
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
q = ggml_scale(ctx0, q, scale);
const float scale = 1.0f / sqrtf(S_v);
q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(b, "b_in", il);
cb(g, "g_in", il);
state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
ggml_tensor * g_t = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs);
ggml_tensor * b_t = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs);
ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
// Apply exponential to g_t
g_t = ggml_exp(ctx0, g_t);
s = ggml_mul(ctx0, s, g_t);
// Apply the gated delta rule for the single timestep
// last_recurrent_state = last_recurrent_state * g_t
state = ggml_mul(ctx0, state, g_t);
ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));
// kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
// we need to sum over dim=-2, so we transpose, sum, then transpose again
kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
ggml_tensor * kv_mem = ggml_mul(ctx0, s_t, k);
kv_mem = ggml_sum_rows(ctx0, kv_mem);
// v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
// delta = (v_t - kv_mem) * beta_t
ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs]
ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);
ggml_tensor * v_diff = ggml_sub(ctx0, v, ggml_transpose(ctx0, kv_mem));
ggml_tensor * delta = ggml_mul(ctx0, v_diff, b_t);
// last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
state = ggml_add(ctx0, state, k_t_delta);
ggml_tensor * k_d = ggml_mul(ctx0, ggml_repeat(ctx0, k, s), ggml_transpose(ctx0, delta));
s_t = ggml_add(ctx0, s_t, k_d);
// Compute the attention output
// core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t
ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
// again, since it's over dim = -2, transpose, sum, transpose back
ggml_tensor * core_attn_out =
ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
ggml_tensor * s_q = ggml_mul(ctx0, s_t, q);
ggml_tensor * output = ggml_sum_rows(ctx0, s_q);
// core_attn_out should be [S_v, 1, H_v, n_seqs] after this
cb(core_attn_out, "output_tokens", il);
cb(state, "new_state", il);
output = ggml_permute(ctx0, output, 2, 0, 1, 3);
s = ggml_transpose(ctx0, s_t);
return {core_attn_out, state};
cb(output, "output_tokens", il);
cb(s, "new_state", il);
return {output, s};
}
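// Similarly, a minimal scalar-loop sketch of the single-token path above, per head and
// following the comments (kv_mem, delta, state update, output); the flat S[dk*dv] layout,
// the dk/dv names and the pre-sigmoided beta are assumptions for readability, not ggml code.
#include <cmath>
#include <vector>

// one token of the gated delta rule, per head; S is the [dk x dv] state, row-major
static void delta_net_step(std::vector<float> & S,
        const float * q, const float * k,   // dk elements each
        const float * v, float * out,       // dv elements each
        float g, float beta, int dk, int dv) {
    const float ge = std::exp(g);
    for (float & s : S) s *= ge;                                  // S = S * exp(g)
    std::vector<float> kv_mem(dv, 0.0f);                          // kv_mem = (S * k[:,None]).sum(dim=-2)
    for (int i = 0; i < dk; ++i)
        for (int j = 0; j < dv; ++j)
            kv_mem[j] += S[(size_t) i*dv + j]*k[i];
    for (int i = 0; i < dk; ++i)                                  // S += k[:,None] * (v - kv_mem) * beta
        for (int j = 0; j < dv; ++j)
            S[(size_t) i*dv + j] += k[i]*(v[j] - kv_mem[j])*beta;
    for (int j = 0; j < dv; ++j) {                                // out = (S * q[:,None]).sum(dim=-2)
        out[j] = 0.0f;
        for (int i = 0; i < dk; ++i)
            out[j] += S[(size_t) i*dv + j]*q[i];
    }
}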
ggml_tensor * llm_build_qwen3next::build_norm_gated(
@@ -473,15 +437,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
// The split should be along dimension 0 (the feature dimension)
ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
ggml_tensor * gate =
ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
cb(Qcur, "Qcur", il);
cb(gate, "gate", il);
// Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cb(Qcur, "Qcur_reshaped", il);
// Apply Q normalization
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
@@ -495,13 +450,10 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
// Apply K normalization
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);
// Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
cb(gate, "gate_reshaped", il);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// Apply RoPE
@@ -527,10 +479,21 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_pregate", il);
ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
cb(gate_sigmoid, "gate_sigmoid", il);
ggml_tensor * gate =
ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
cb(Qcur, "Qcur", il);
cb(gate, "gate", il);
cur = ggml_mul(ctx0, cur, gate_sigmoid);
// TODO: CUDA is missing non-contiguous unary ops; when implemented, remove this cont
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
gate = ggml_sigmoid(ctx0, gate);
cb(gate, "gate_sigmoid", il);
gate = ggml_reshape_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
cur = ggml_mul(ctx0, cur, gate);
cb(cur, "attn_gated", il);
cur = build_lora_mm(model.layers[il].wo, cur);
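// Per element, the gating above amounts to scaling the attention output by a sigmoid of the
// gate half of the fused Q projection; a scalar sketch with illustrative names:
#include <cmath>

// attn_gated = attn_pregate * sigmoid(gate)
static inline float gate_attn_output(float attn_pregate, float gate) {
    return attn_pregate*(1.0f/(1.0f + std::exp(-gate)));
}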
@@ -671,7 +634,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
cb(a, "a", il);
ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
// TODO: CUDA is missing non-contiguous unary ops; when implemented, remove this cont
b = ggml_cont(ctx0, b);
ggml_tensor * beta = ggml_sigmoid(ctx0, b);
beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs);
// Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
@@ -679,6 +647,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
cb(alpha_softplus, "a_softplus", il);
ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
cb(gate, "gate", il);
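// Per head and token, the decay gate above reduces to the scalar below, reading ssm_dt as a
// dt bias and ssm_a as -exp(A_log) as the in-code comment suggests; the names are illustrative.
#include <cmath>

static inline float softplus(float x) { return x > 20.0f ? x : std::log1p(std::exp(x)); }

// gate = -exp(A_log) * softplus(a + dt_bias)
static inline float delta_net_decay_gate(float a, float dt_bias, float neg_exp_A_log) {
    return neg_exp_A_log*softplus(a + dt_bias);
}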
@@ -696,11 +665,11 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
const int64_t conv_kernel_size = conv_kernel->ne[0];
const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
cb(conv_states, "conv_states_reshaped", il);
qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
cb(qkv_mixed, "qkv_mixed_permuted", il);
qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
cb(qkv_mixed, "qkv_mixed_transposed", il);
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
cb(conv_input, "conv_input", il);
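// The concat above keeps the convolution causal across calls: the cached last
// (conv_kernel_size - 1) time steps are prepended to the new tokens along dim 0, so the
// convolution that follows sees a full window for the first new token. Per channel this is
// simply (illustrative, plain vectors):
#include <vector>

// conv_input = concat(conv_state, new_tokens) along time; length = (kernel_size - 1) + n_tokens
static std::vector<float> build_conv_input(const std::vector<float> & conv_state,
                                           const std::vector<float> & new_tokens) {
    std::vector<float> input(conv_state);
    input.insert(input.end(), new_tokens.begin(), new_tokens.end());
    return input;
}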
@@ -734,25 +703,40 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
// Extract the convolved Q, K, V from conv_output
ggml_tensor * q_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
ggml_row_size(conv_qkv_mix->type, head_k_dim),
nb1_qkv,
nb1_qkv * n_seq_tokens,
0);
ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
ggml_row_size(conv_qkv_mix->type, head_k_dim),
nb1_qkv,
nb1_qkv * n_seq_tokens,
head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
ggml_row_size(conv_qkv_mix->type, head_v_dim),
nb1_qkv,
nb1_qkv * n_seq_tokens,
ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads));
cb(q_conv, "q_conv", il);
ggml_tensor * k_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
cb(k_conv, "k_conv", il);
ggml_tensor * v_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
cb(v_conv, "v_conv", il);
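// The three views above carve one fused row of conv_qkv_mix at fixed element offsets; a small
// helper makes the assumed layout explicit (q, then k, then v, with qkv_dim as the row length;
// illustrative, not part of the change):
#include <cstdint>

struct qkv_row_layout {
    int64_t q_off, k_off, v_off, row_len; // offsets in elements within one fused row
};

static qkv_row_layout make_qkv_row_layout(int64_t head_k_dim, int64_t num_k_heads,
                                          int64_t head_v_dim, int64_t num_v_heads) {
    qkv_row_layout l;
    l.q_off   = 0;
    l.k_off   = head_k_dim*num_k_heads;
    l.v_off   = 2*head_k_dim*num_k_heads;
    l.row_len = l.v_off + head_v_dim*num_v_heads;
    return l;
}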
const float eps_norm = hparams.f_norm_rms_eps;
q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);
// Unsqueeze them
q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
//q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
//k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
cb(state, "state_predelta", il);
// if the number of k heads and v heads differ, repeat to force the tensors into matching shapes
@@ -795,19 +779,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
// Update the recurrent states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, new_state,
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
// Reshape both attn_out_final and z to 2D tensors for normalization
// attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
ggml_cpy(ctx0, new_state,
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
// Apply gated normalization: self.norm(core_attn_out, z)
ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il);
// Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
@@ -818,7 +798,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
cb(cur, "linear_attn_out", il);
// Reshape back to original dimensions
cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
return cur;
}
@@ -2964,11 +2964,12 @@ struct test_bin_bcast : public test_case {
const std::array<int64_t, 4> ne;
const std::array<int, 4> nr;
int nf; // number of fused ops, nf == 1 -> single op (no fusion)
bool perm1; // permute src1?
bool run_whole_graph() override { return nf > 1; }
std::string vars() override {
return VARS_TO_STR4(type, ne, nr, nf);
return VARS_TO_STR5(type, ne, nr, nf, perm1);
}
size_t op_size(ggml_tensor * t) override {
@@ -2978,8 +2979,9 @@ struct test_bin_bcast : public test_case {
test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 1, 1},
std::array<int, 4> nr = {1, 2, 1, 1},
int nf = 1)
: op(op), type(type), ne(ne), nr(nr), nf(nf) {}
int nf = 1,
bool perm1 = false)
: op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
GGML_ASSERT(nf <= 16);
@@ -2989,12 +2991,19 @@ struct test_bin_bcast : public test_case {
ggml_tensor * b[16];
for (int i = 0; i < nf; ++i) {
b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
if (perm1) {
const int p[4] = { 1, 2, 0, 3 }; // hardcoded for now
b[i] = ggml_new_tensor_4d(ctx, type, ne[p[0]], ne[p[1]], ne[p[2]], ne[p[3]]);
b[i] = ggml_permute(ctx, b[i], p[0], p[1], p[2], p[3]);
} else {
b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
}
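// with perm1, b[i] keeps the logical shape ne of a, but with permuted, non-contiguous strides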
ggml_set_name(b[i], (std::string("b") + std::to_string(i)).c_str());
}
// The backward pass supports broadcasting only for GGML_ADD:
const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1;
const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1 && !perm1;
if (grad_supported) {
ggml_set_param(a);
ggml_set_param(b[0]);
@@ -7477,25 +7486,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr, bool perm1 = false) {
for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1));
}
};
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1});
add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1});
add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1});
add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1});
add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1});
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1});
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1});
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1});
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1});
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2});
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2});
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
for (bool perm1 : {false, true}) {
add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1}, perm1);
add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1}, perm1);
add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1}, perm1);
add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1}, perm1);
add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1}, perm1);
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1}, perm1);
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1}, perm1);
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1}, perm1);
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1}, perm1);
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2}, perm1);
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2}, perm1);
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2}, perm1);
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}, perm1);
}
// test case for k_bin_bcast_unravel in CUDA backend
add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});
@@ -7882,20 +7893,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_round (type));
test_cases.emplace_back(new test_trunc (type));
test_cases.emplace_back(new test_sqr (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_sqr (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_sqrt (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_sqrt (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_log (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_log (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_sin (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_sin (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_cos (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_cos (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_clamp (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_leaky_relu(type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_floor (type, { 1024, 1024, 1, 1 }));
test_cases.emplace_back(new test_floor (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_ceil (type, { 1024, 1024, 1, 1 }));
test_cases.emplace_back(new test_ceil (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_round (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_round (type, { 1024, 1024, 1, 1 }));
test_cases.emplace_back(new test_round (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_trunc (type, { 1024, 1024, 1, 1 }));
test_cases.emplace_back(new test_trunc (type, {1024, 1024, 1, 1}));
}
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));