Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2026-02-12 14:03:20 +02:00

Compare commits

6 Commits
| Author | SHA1 | Date |
|---|---|---|
| | a4ea7a188f | |
| | 7a4f97d196 | |
| | a498c75ad1 | |
| | 3409ab842d | |
| | c342c3b93d | |
| | af252d0758 | |
@@ -176,6 +176,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_me
    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) {
    char base[256];
    char name[256];

    const int n = op->src[0]->ne[0];

    snprintf(base, 256, "kernel_diag_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s_n=%d", base, n);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    res.nsg  = 1;
    res.smem = 0;

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
    char base[256];
    char name[256];
@@ -108,6 +108,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d  (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag     (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat   (ggml_metal_library_t lib, enum ggml_type tsrc);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary    (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu      (ggml_metal_library_t lib, const struct ggml_tensor * op);
@@ -1152,8 +1152,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
            return has_simdgroup_reduction;
        case GGML_OP_RWKV_WKV6:
        case GGML_OP_RWKV_WKV7:
        case GGML_OP_SOLVE_TRI:
            return true;
        case GGML_OP_SOLVE_TRI:
        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
            return has_simdgroup_reduction;
@@ -1235,6 +1235,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                        return false;
                };
            }
        case GGML_OP_DIAG:
            return true;
        case GGML_OP_OPT_STEP_ADAMW:
        case GGML_OP_OPT_STEP_SGD:
            return has_simdgroup_reduction;
@@ -792,6 +792,25 @@ typedef struct {
    uint64_t nb3;
} ggml_metal_kargs_set_rows;

typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_diag;

typedef struct {
    int64_t  ne00;
    int64_t  ne01;
@@ -361,6 +361,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
            {
                n_fuse = ggml_metal_op_set_rows(ctx, idx);
            } break;
        case GGML_OP_DIAG:
            {
                n_fuse = ggml_metal_op_diag(ctx, idx);
            } break;
        case GGML_OP_L2_NORM:
            {
                n_fuse = ggml_metal_op_l2_norm(ctx, idx);
@@ -1259,6 +1263,48 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
    return 1;
}

int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    GGML_TENSOR_LOCALS(int32_t,  ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS(int32_t,  ne,  op,         ne);
    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);

    ggml_metal_kargs_diag args = {
        /*.ne00 =*/ ne00,
        /*.ne01 =*/ ne01,
        /*.ne02 =*/ ne02,
        /*.ne03 =*/ ne03,
        /*.nb00 =*/ nb00,
        /*.nb01 =*/ nb01,
        /*.nb02 =*/ nb02,
        /*.nb03 =*/ nb03,
        /*.ne0  =*/ ne0,
        /*.ne1  =*/ ne1,
        /*.ne2  =*/ ne2,
        /*.ne3  =*/ ne3,
        /*.nb0  =*/ nb0,
        /*.nb1  =*/ nb1,
        /*.nb2  =*/ nb2,
        /*.nb3  =*/ nb3,
    };

    auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2);

    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);

    return 1;
}

int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);
@@ -56,6 +56,7 @@ int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cumsum   (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_diag     (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
@@ -7,6 +7,9 @@
#include "ggml-metal-context.h"
#include "ggml-metal-ops.h"

#include <mutex>
#include <string>

#define GGML_METAL_NAME "MTL"
#define GGML_METAL_MAX_DEVICES 16
@@ -8815,6 +8815,26 @@ kernel void kernel_set_rows_f(
    }
}

kernel void kernel_diag_f32(
        constant ggml_metal_kargs_diag & args,
        device const char * src0,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]]) {
    constexpr short NW = N_SIMDWIDTH;

    const int32_t i3 = tgpig.z;
    const int32_t i2 = tgpig.y;
    const int32_t i1 = tgpig.x;

    device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03);
    device       float * dst_ptr  = (device       float *)(dst  + i1*args.nb01 + i2*args.nb2 + i3*args.nb3);

    for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
        dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
    }
}

constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
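For orientation, here is a minimal CPU-side sketch of what kernel_diag_f32 computes, assuming a contiguous F32 source of shape [ne0, 1, ne2, ne3] as produced by ggml_diag. The helper name and layout assumptions are illustrative only and are not part of this change.

```cpp
// Illustrative reference (not from the diff): for each (i2, i3) slice, the single
// input row of length ne0 is expanded into an ne0 x ne0 matrix with that row on
// the diagonal and zeros elsewhere, mirroring dst_ptr[i0] = (i0 == i1) ? src[i0] : 0.
#include <cstddef>
#include <vector>

static std::vector<float> diag_reference(const std::vector<float> & src,
                                         int ne0, int ne2, int ne3) {
    // src laid out as [ne0, 1, ne2, ne3]; dst as [ne0, ne0, ne2, ne3]
    std::vector<float> dst((size_t) ne0 * ne0 * ne2 * ne3, 0.0f);
    for (int i3 = 0; i3 < ne3; ++i3) {
        for (int i2 = 0; i2 < ne2; ++i2) {
            const float * s = src.data() + ((size_t) i3 * ne2 + i2) * ne0;
            float       * d = dst.data() + ((size_t) i3 * ne2 + i2) * ne0 * ne0;
            for (int i1 = 0; i1 < ne0; ++i1) {
                d[(size_t) i1 * ne0 + i1] = s[i1]; // diagonal element of row i1
            }
        }
    }
    return dst;
}
```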
@@ -1263,25 +1263,30 @@ struct vk_op_diag_mask_push_constants {

struct vk_op_rope_push_constants {
    uint32_t rope_mode;
    uint32_t ncols;
    uint32_t nrows;
    uint32_t n_dims;
    float freq_scale;
    uint32_t p_delta_rows;
    float freq_base;
    float ext_factor;
    float attn_factor;
    float corr_dims[2];
    float theta_scale;
    uint32_t has_ff;
    uint32_t ne02;
    uint32_t s1;
    uint32_t s2;
    int32_t sections[4];
    uint32_t is_imrope;
    uint32_t is_back;
    uint32_t set_rows_stride;
    uint32_t ne00;
    uint32_t ne01;
    uint32_t ne02;
    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;
};
static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");

// For fused rms_norm+mul+rope(+view+set_rows)
struct vk_op_rms_norm_mul_rope_push_constants {
@@ -3199,9 +3204,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
        const uint32_t D_lsb = D ^ (D & (D-1));
        uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);

        // Nvidia prefers shared memory use to load large tiles of K
        // Nvidia prefers shared memory use to load large tiles of K.
        // Switch to loading from global memory when it would use too much shared memory.
        // AMD prefers loading K directly from global memory
        const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA ? 1 : 0;
        const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;

        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem};
    };
@@ -5555,9 +5561,9 @@ static void ggml_vk_instance_init() {
        // Check if there are two physical devices corresponding to the same GPU
        // This handles the case where the same GPU appears with different drivers (e.g., RADV + AMDVLK on Linux),
        // see https://github.com/ggml-org/llama.cpp/pull/7582 for original deduplication.
        // However, for MoltenVK on macOS, multiple GPUs on the same card may report the same UUID,
        // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Until this is fixed, we'll only deduplicate
        // when drivers differ (same driver + same UUID = likely different GPUs)
        // MoltenVK on macOS may report the same UUID for distinct GPUs on multi-GPU cards,
        // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Skip when both old/new
        // driver is MoltenVK
        auto old_device = std::find_if(
            vk_instance.device_indices.begin(),
            vk_instance.device_indices.end(),
@@ -5574,11 +5580,9 @@ static void ggml_vk_instance_init() {
                    old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
                    std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
                );
                bool both_molten_vk = (new_driver.driverID == vk::DriverId::eMoltenvk && old_driver.driverID == vk::DriverId::eMoltenvk);

                // Only deduplicate if same UUID AND different drivers
                // (same driver + same UUID on MoltenVK = likely different GPUs on multi-GPU card)
                bool different_driver = (old_driver.driverID != new_driver.driverID);
                return same_uuid && different_driver;
                return same_uuid && !both_molten_vk;
            }
        );
        if (old_device == vk_instance.device_indices.end()) {
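A minimal sketch of the revised deduplication rule, pulled out of the lambda for readability; the helper name is hypothetical and the snippet only restates the logic shown above.

```cpp
// Hypothetical helper (not part of the diff): a device pair with matching
// UUID/LUID is treated as a duplicate unless both entries come from MoltenVK,
// where identical UUIDs can belong to distinct GPUs on multi-GPU cards.
#include <vulkan/vulkan.hpp>

static bool is_duplicate_device(bool same_uuid, vk::DriverId old_driver, vk::DriverId new_driver) {
    const bool both_molten_vk = old_driver == vk::DriverId::eMoltenvk &&
                                new_driver == vk::DriverId::eMoltenvk;
    return same_uuid && !both_molten_vk;
}
```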
@@ -8407,7 +8411,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
    const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
    const uint32_t sfsh = Bc * sfshstride * acctype;

    const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA;
    const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256;
    const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2;
    const uint32_t vsh_stride = MatBc / 4 * row_split;
    const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4;
@@ -10405,12 +10409,22 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *

    uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
    uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
    uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type);

    uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type);
    uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type);
    uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type);

    vk_op_rope_push_constants rope {
        (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
        has_ff, (uint32_t)src0->ne[2], nb01, nb02,
        (uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale,
        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff,
        { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,

        (uint32_t)src0->ne[0],
        (uint32_t)src0->ne[1],
        (uint32_t)src0->ne[2],
        nb01, nb02, nb03,
        nb11, nb12, nb13,
    };

    return rope;
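The strides passed here are in elements rather than bytes, hence the division by ggml_type_size. A small self-contained example of that conversion for a hypothetical contiguous F32 tensor (the shape and values are illustrative, not taken from the change):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // hypothetical contiguous F32 tensor: ne = {128, 32, 8, 2}
    const uint64_t ne[4]     = {128, 32, 8, 2};
    const uint64_t type_size = 4; // sizeof(float)

    // ggml-style byte strides for a contiguous tensor: nb[0] = type_size, nb[i] = nb[i-1]*ne[i-1]
    uint64_t nb[4];
    nb[0] = type_size;
    for (int i = 1; i < 4; ++i) {
        nb[i] = nb[i - 1] * ne[i - 1];
    }

    // element strides as they would be placed in the rope push constants (nb01, nb02, nb03)
    std::printf("nb01 = %llu, nb02 = %llu, nb03 = %llu elements\n",
                (unsigned long long) (nb[1] / type_size),
                (unsigned long long) (nb[2] / type_size),
                (unsigned long long) (nb[3] / type_size));
    return 0;
}
```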
@@ -14798,6 +14812,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_REPEAT_BACK:
            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_ROPE:
            return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]);
        case GGML_OP_ROPE_BACK:
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
@@ -112,12 +112,11 @@ void rms_norm(uint num_iters) {
#if RMS_NORM_ROPE_FUSION
    barrier();
    rope_params rp = p.rope;
    uint rope_row = (samp*nchannels + channel)*nrows + row;
    for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
        if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
            rope_neox(t, rope_row, rp);
            rope_neox(t, row, channel, samp, rp);
        } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
            rope_norm(t, rope_row, rp);
            rope_norm(t, row, channel, samp, rp);
        }
    }
#endif
@@ -4,12 +4,12 @@ float rope_yarn_ramp(const float low, const float high, const uint i0) {
    return 1.0f - min(1.0f, max(0.0f, y));
}

uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) {
uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) {
#if RMS_NORM_ROPE_FUSION
    // Per-row offset in shared memory
    const uint ix = i0;
#else
    const uint ix = i02*p.nb02 + i01*p.nb01 + i0;
    const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0;
#endif
    return ix;
}
@@ -34,26 +34,19 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out
    sin_theta = sin(theta) * mscale;
}

void rope_norm(const uint i0, const uint i1, rope_params p) {
    uint ne0 = p.ncols;
    uint ne1 = p.p_delta_rows;

    if (i0 >= ne0) {
void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
    if (i0 >= p.ne00) {
        return;
    }

    // i1 is actually i2*nb2+i1, but the rows are contiguous
    const uint i01 = i1 % ne1;
    const uint i02 = i1 / ne1;

    uint idst = i1*ne0 + i0;
    const uint ix = rope_a_coord(i0, i01, i02, p);
    uint idst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
    const uint ix = rope_a_coord(i0, i1, i2, i3, p);

    // Fusion optimization: ROPE + VIEW + SET_ROWS.
    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
    if (p.set_rows_stride != 0) {
        idst = i01*ne0 + i0;
        idst += rope_data_i[i02].x * p.set_rows_stride;
        idst = i1*p.nb11 + i0;
        idst += rope_data_i[i2].x * p.set_rows_stride;
    }

    if (i0 >= p.n_dims) {

@@ -63,7 +56,7 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
        return;
    }

    const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
    const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);

    const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
@@ -77,25 +70,19 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
    rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
}

void rope_neox(const uint i0, const uint i1, rope_params p) {
    uint ne0 = p.ncols;
    uint ne1 = p.p_delta_rows;

    if (i0 >= ne0) {
void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
    if (i0 >= p.ne00) {
        return;
    }

    const uint i01 = i1 % ne1;
    const uint i02 = i1 / ne1;

    uint idst = i1*ne0 + i0/2;
    const uint ix = rope_a_coord(i0/2, i01, i02, p);
    uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);

    // Fusion optimization: ROPE + VIEW + SET_ROWS.
    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
    if (p.set_rows_stride != 0) {
        idst = i01*ne0 + i0/2;
        idst += rope_data_i[i02].x * p.set_rows_stride;
        idst = i1*p.nb11 + i0/2;
        idst += rope_data_i[i2].x * p.set_rows_stride;
    }

    if (i0 >= p.n_dims) {

@@ -105,7 +92,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
        return;
    }

    const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
    const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);

    const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
@@ -120,26 +107,19 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
}


void rope_multi(const uint i0, const uint i1, rope_params p) {
    uint ne0 = p.ncols;
    uint ne1 = p.p_delta_rows;
    uint ne2 = p.ne02;

    if (i0 >= ne0) {
void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
    if (i0 >= p.ne00) {
        return;
    }

    const uint i01 = i1 % ne1;
    const uint i02 = i1 / ne1;

    uint idst = i1*ne0 + i0/2;
    const uint ix = rope_a_coord(i0/2, i01, i02, p);
    uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);

    // Fusion optimization: ROPE + VIEW + SET_ROWS.
    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
    if (p.set_rows_stride != 0) {
        idst = i01*ne0 + i0/2;
        idst += rope_data_i[i02].x * p.set_rows_stride;
        idst = i1*p.nb11 + i0/2;
        idst += rope_data_i[i2].x * p.set_rows_stride;
    }

    if (i0 >= p.n_dims) {

@@ -156,26 +136,26 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
    float theta_base = 0.0;
    if (p.is_imrope != 0) {
        if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
            theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
            theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
        } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
            theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
            theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
        } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
            theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
            theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
        } else {
            theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
            theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
        }
    } else {
        if (sector < p.sections[0]) {
            theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
            theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
        }
        else if (sector >= p.sections[0] && sector < sec_w) {
            theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
            theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
        }
        else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
            theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
            theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
        }
        else if (sector >= sec_w + p.sections[2]) {
            theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
            theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
        }
    }
@@ -191,20 +171,13 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
    rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
}

void rope_vision(const uint i0, const uint i1, rope_params p) {
    uint ne0 = p.ncols;
    uint ne1 = p.p_delta_rows;
    uint ne2 = p.ne02;

    if (i0 >= ne0) {
void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
    if (i0 >= p.ne00) {
        return;
    }

    const uint i01 = i1 % ne1;
    const uint i02 = i1 / ne1;

    const uint idst = i1*ne0 + i0/2;
    const uint ix = rope_a_coord(i0/2, i01, i02, p);
    const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);

    const int sect_dims = p.sections[0] + p.sections[1];
    const int sec_w = p.sections[1] + p.sections[0];

@@ -213,11 +186,11 @@ void rope_vision(const uint i0, const uint i1, rope_params p) {
    float theta_base = 0.0;
    if (sector < p.sections[0]) {
        const uint p0 = sector;
        theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0);
        theta_base = rope_data_pos[i2]*pow(p.theta_scale, p0);
    }
    else if (sector >= p.sections[0] && sector < sec_w) {
        const uint p0 = sector - p.sections[0];
        theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0);
        theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, p0);
    }

    const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
@@ -5,10 +5,13 @@

void main() {
    const uint i0 = 2*gl_GlobalInvocationID.y;
    // i1 is actually i2*nb2+i1, but the rows are contiguous
    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
    if (i1 >= pc.nrows) {
    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
    if (row >= pc.nrows) {
        return;
    }
    rope_multi(i0, i1, pc);
    const uint i3 = row / (pc.ne01*pc.ne02);
    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);

    rope_multi(i0, i1, i2, i3, pc);
}
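The same flat-row decomposition is repeated in the other rope shader entry points below. A small worked example with assumed shapes (not taken from the change) shows how a global row index splits into (i1, i2, i3):

```cpp
#include <cassert>

int main() {
    // assumed shape: ne01 = 32 rows per channel, ne02 = 8 channels
    const unsigned ne01 = 32, ne02 = 8;
    const unsigned row  = 437; // 437 = 1*(32*8) + 5*32 + 21

    const unsigned i3 = row / (ne01 * ne02);                 // sample index
    const unsigned i2 = (row - i3 * ne01 * ne02) / ne01;     // channel index
    const unsigned i1 =  row - i3 * ne01 * ne02 - i2 * ne01; // row within the channel

    assert(i3 == 1 && i2 == 5 && i1 == 21);
    return 0;
}
```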
@@ -5,10 +5,13 @@

void main() {
    const uint i0 = 2*gl_GlobalInvocationID.y;
    // i1 is actually i2*nb2+i1, but the rows are contiguous
    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
    if (i1 >= pc.nrows) {
    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
    if (row >= pc.nrows) {
        return;
    }
    rope_neox(i0, i1, pc);
    const uint i3 = row / (pc.ne01*pc.ne02);
    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);

    rope_neox(i0, i1, i2, i3, pc);
}
@@ -5,10 +5,13 @@

void main() {
    const uint i0 = 2*gl_GlobalInvocationID.y;
    // i1 is actually i2*nb2+i1, but the rows are contiguous
    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
    if (i1 >= pc.nrows) {
    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
    if (row >= pc.nrows) {
        return;
    }
    rope_norm(i0, i1, pc);
    const uint i3 = row / (pc.ne01*pc.ne02);
    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);

    rope_norm(i0, i1, i2, i3, pc);
}
@@ -5,24 +5,29 @@

struct rope_params {
    uint rope_mode;
    uint ncols;
    uint nrows;
    uint n_dims;
    float freq_scale;
    uint p_delta_rows;
    float freq_base;
    float ext_factor;
    float attn_factor;
    float corr_dims[2];
    float theta_scale;
    uint has_ff;
    uint ne02;
    uint nb01;
    uint nb02;
    int sections[4];
    uint is_imrope;
    uint is_back;
    uint set_rows_stride;

    uint ne00;
    uint ne01;
    uint ne02;
    uint nb01;
    uint nb02;
    uint nb03;
    uint nb11;
    uint nb12;
    uint nb13;
};

#endif // !defined(GGML_ROPE_PARAMS)
@@ -5,10 +5,13 @@

void main() {
    const uint i0 = 2*gl_GlobalInvocationID.y;
    // i1 is actually i2*nb2+i1, but the rows are contiguous
    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
    if (i1 >= pc.nrows) {
    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
    if (row >= pc.nrows) {
        return;
    }
    rope_vision(i0, i1, pc);
    const uint i3 = row / (pc.ne01*pc.ne02);
    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);

    rope_vision(i0, i1, i2, i3, pc);
}
vendor/cpp-httplib/CMakeLists.txt (vendored, 2 changes)
@@ -39,7 +39,7 @@ if (LLAMA_BUILD_BORINGSSL)
    set(FIPS OFF CACHE BOOL "Enable FIPS (BoringSSL)")

    set(BORINGSSL_GIT "https://boringssl.googlesource.com/boringssl" CACHE STRING "BoringSSL git repository")
    set(BORINGSSL_VERSION "0.20251002.0" CACHE STRING "BoringSSL version")
    set(BORINGSSL_VERSION "0.20260204.0" CACHE STRING "BoringSSL version")

    message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}")