mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-05 13:53:23 +02:00
metal : add diag (#19330)
This commit is contained in:
@@ -176,6 +176,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_me
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const int n = op->src[0]->ne[0];
|
||||
|
||||
snprintf(base, 256, "kernel_diag_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s_n=%d", base, n);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
}
|
||||
|
||||
res.nsg = 1;
|
||||
res.smem = 0;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
@@ -108,6 +108,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
@@ -1152,8 +1152,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
return true;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return has_simdgroup_reduction;
|
||||
@@ -1235,6 +1235,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return false;
|
||||
};
|
||||
}
|
||||
case GGML_OP_DIAG:
|
||||
return true;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
return has_simdgroup_reduction;
|
||||
|
||||
@@ -792,6 +792,25 @@ typedef struct {
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_set_rows;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_diag;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
|
||||
@@ -361,6 +361,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_set_rows(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_DIAG:
|
||||
{
|
||||
n_fuse = ggml_metal_op_diag(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_L2_NORM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
||||
@@ -1259,6 +1263,48 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS(int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_diag args = {
|
||||
/*.ne00 =*/ne00,
|
||||
/*.ne01 =*/ne01,
|
||||
/*.ne02 =*/ne02,
|
||||
/*.ne03 =*/ne03,
|
||||
/*.nb00 =*/nb00,
|
||||
/*.nb01 =*/nb01,
|
||||
/*.nb02 =*/nb02,
|
||||
/*.nb03 =*/nb03,
|
||||
/*.ne0 =*/ne0,
|
||||
/*.ne1 =*/ne1,
|
||||
/*.ne2 =*/ne2,
|
||||
/*.ne3 =*/ne3,
|
||||
/*.nb0 =*/nb0,
|
||||
/*.nb1 =*/nb1,
|
||||
/*.nb2 =*/nb2,
|
||||
/*.nb3 =*/nb3,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
|
||||
@@ -56,6 +56,7 @@ int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_diag (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
@@ -8815,6 +8815,26 @@ kernel void kernel_set_rows_f(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_diag_f32(
|
||||
constant ggml_metal_kargs_diag & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]]) {
|
||||
constexpr short NW = N_SIMDWIDTH;
|
||||
|
||||
const int32_t i3 = tgpig.z;
|
||||
const int32_t i2 = tgpig.y;
|
||||
const int32_t i1 = tgpig.x;
|
||||
|
||||
device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
|
||||
dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
|
||||
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
|
||||
|
||||
|
||||
Reference in New Issue
Block a user