mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-05 13:53:23 +02:00
vulkan: fix non-contig rope (#19299)
This commit is contained in:
@@ -1263,25 +1263,30 @@ struct vk_op_diag_mask_push_constants {
|
||||
|
||||
// Push-constant block handed to the Vulkan rope shaders.
//
// NOTE(review): this listing is a rendered diff whose +/- markers were lost, so
// fields from before and after the change appear interleaved (ne02 is listed
// twice). Judging by the aggregate initializer in ggml_vk_make_rope_constants,
// the post-change layout appears to be: rope_mode, nrows, n_dims, freq_scale,
// freq_base, ext_factor, attn_factor, corr_dims, theta_scale, has_ff, sections,
// is_imrope, is_back, set_rows_stride, ne00..ne02, nb01..nb03, nb11..nb13 --
// confirm against the real file before relying on it.
//
// The static_assert that follows the struct caps its size at 128 bytes.
// Presumably the field order must mirror the shader-side rope_params struct,
// since the struct is filled by positional aggregate initialization -- TODO confirm.
struct vk_op_rope_push_constants {
|
||||
uint32_t rope_mode;
|
||||
uint32_t ncols;
|
||||
uint32_t nrows;
|
||||
uint32_t n_dims;
|
||||
float freq_scale;
|
||||
uint32_t p_delta_rows;
|
||||
float freq_base;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float corr_dims[2];
|
||||
float theta_scale;
|
||||
uint32_t has_ff;
|
||||
uint32_t ne02;
|
||||
uint32_t s1;
|
||||
uint32_t s2;
|
||||
int32_t sections[4];
|
||||
uint32_t is_imrope;
|
||||
uint32_t is_back;
|
||||
uint32_t set_rows_stride;
|
||||
// ne00/ne01/ne02: filled from src0->ne[0], src0->ne[1], src0->ne[2].
|
||||
uint32_t ne00;
|
||||
|
||||
uint32_t ne01;
|
||||
uint32_t ne02;
|
||||
// nb01/nb02/nb03: src0 strides in elements (src0->nb[i] / type size).
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
// nb11/nb12/nb13: dst strides in elements (dst->nb[i] / type size).
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
};
|
||||
static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");
|
||||
|
||||
// For fused rms_norm+mul+rope(+view+set_rows)
|
||||
struct vk_op_rms_norm_mul_rope_push_constants {
|
||||
@@ -10405,12 +10410,22 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
|
||||
|
||||
uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
|
||||
uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
|
||||
uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type);
|
||||
|
||||
uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type);
|
||||
uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type);
|
||||
uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type);
|
||||
|
||||
vk_op_rope_push_constants rope {
|
||||
(uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
|
||||
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
|
||||
has_ff, (uint32_t)src0->ne[2], nb01, nb02,
|
||||
(uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff,
|
||||
{ sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
|
||||
|
||||
(uint32_t)src0->ne[0],
|
||||
(uint32_t)src0->ne[1],
|
||||
(uint32_t)src0->ne[2],
|
||||
nb01, nb02, nb03,
|
||||
nb11, nb12, nb13,
|
||||
};
|
||||
|
||||
return rope;
|
||||
@@ -14798,6 +14813,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ROPE:
|
||||
return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_ROPE_BACK:
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
|
||||
@@ -112,12 +112,11 @@ void rms_norm(uint num_iters) {
|
||||
#if RMS_NORM_ROPE_FUSION
|
||||
barrier();
|
||||
rope_params rp = p.rope;
|
||||
uint rope_row = (samp*nchannels + channel)*nrows + row;
|
||||
for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
|
||||
if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
|
||||
rope_neox(t, rope_row, rp);
|
||||
rope_neox(t, row, channel, samp, rp);
|
||||
} else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
|
||||
rope_norm(t, rope_row, rp);
|
||||
rope_norm(t, row, channel, samp, rp);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -4,12 +4,12 @@ float rope_yarn_ramp(const float low, const float high, const uint i0) {
|
||||
return 1.0f - min(1.0f, max(0.0f, y));
|
||||
}
|
||||
|
||||
// Returns the element index into the rope input for coordinate (i0, i01, i02, i03).
//
// NOTE(review): rendered diff with +/- markers lost; both the three-index and
// the four-index signature (and their matching index expressions) are shown.
// The four-index form adds the i03*p.nb03 term.
//
// With RMS_NORM_ROPE_FUSION the index is just i0; otherwise it is the strided
// sum of the coordinates using the nb01/nb02/nb03 strides from rope_params.
uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) {
|
||||
uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) {
|
||||
#if RMS_NORM_ROPE_FUSION
|
||||
// Per-row offset in shared memory
|
||||
const uint ix = i0;
|
||||
#else
|
||||
// i0 is added unscaled, i.e. no stride is applied along dimension 0.
|
||||
const uint ix = i02*p.nb02 + i01*p.nb01 + i0;
|
||||
const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0;
|
||||
#endif
|
||||
return ix;
|
||||
}
|
||||
@@ -34,26 +34,19 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out
|
||||
sin_theta = sin(theta) * mscale;
|
||||
}
|
||||
|
||||
void rope_norm(const uint i0, const uint i1, rope_params p) {
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
|
||||
if (i0 >= p.ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||
const uint i01 = i1 % ne1;
|
||||
const uint i02 = i1 / ne1;
|
||||
|
||||
uint idst = i1*ne0 + i0;
|
||||
const uint ix = rope_a_coord(i0, i01, i02, p);
|
||||
uint idst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
|
||||
const uint ix = rope_a_coord(i0, i1, i2, i3, p);
|
||||
|
||||
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
||||
// The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
|
||||
if (p.set_rows_stride != 0) {
|
||||
idst = i01*ne0 + i0;
|
||||
idst += rope_data_i[i02].x * p.set_rows_stride;
|
||||
idst = i1*p.nb11 + i0;
|
||||
idst += rope_data_i[i2].x * p.set_rows_stride;
|
||||
}
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
@@ -63,7 +56,7 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
|
||||
const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
|
||||
|
||||
@@ -77,25 +70,19 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
|
||||
rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
|
||||
void rope_neox(const uint i0, const uint i1, rope_params p) {
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
|
||||
if (i0 >= p.ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i01 = i1 % ne1;
|
||||
const uint i02 = i1 / ne1;
|
||||
|
||||
uint idst = i1*ne0 + i0/2;
|
||||
const uint ix = rope_a_coord(i0/2, i01, i02, p);
|
||||
uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
|
||||
const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
|
||||
|
||||
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
||||
// The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
|
||||
if (p.set_rows_stride != 0) {
|
||||
idst = i01*ne0 + i0/2;
|
||||
idst += rope_data_i[i02].x * p.set_rows_stride;
|
||||
idst = i1*p.nb11 + i0/2;
|
||||
idst += rope_data_i[i2].x * p.set_rows_stride;
|
||||
}
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
@@ -105,7 +92,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
|
||||
const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
|
||||
|
||||
@@ -120,26 +107,19 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
|
||||
}
|
||||
|
||||
|
||||
void rope_multi(const uint i0, const uint i1, rope_params p) {
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
uint ne2 = p.ne02;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
|
||||
if (i0 >= p.ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i01 = i1 % ne1;
|
||||
const uint i02 = i1 / ne1;
|
||||
|
||||
uint idst = i1*ne0 + i0/2;
|
||||
const uint ix = rope_a_coord(i0/2, i01, i02, p);
|
||||
uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
|
||||
const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
|
||||
|
||||
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
||||
// The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
|
||||
if (p.set_rows_stride != 0) {
|
||||
idst = i01*ne0 + i0/2;
|
||||
idst += rope_data_i[i02].x * p.set_rows_stride;
|
||||
idst = i1*p.nb11 + i0/2;
|
||||
idst += rope_data_i[i2].x * p.set_rows_stride;
|
||||
}
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
@@ -156,26 +136,26 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
|
||||
float theta_base = 0.0;
|
||||
if (p.is_imrope != 0) {
|
||||
if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
|
||||
theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
|
||||
theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
|
||||
theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
|
||||
theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
|
||||
} else {
|
||||
theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
} else {
|
||||
if (sector < p.sections[0]) {
|
||||
theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
|
||||
theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= p.sections[0] && sector < sec_w) {
|
||||
theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
|
||||
theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + p.sections[2]) {
|
||||
theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,20 +171,13 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
|
||||
rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
|
||||
void rope_vision(const uint i0, const uint i1, rope_params p) {
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
uint ne2 = p.ne02;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
|
||||
if (i0 >= p.ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i01 = i1 % ne1;
|
||||
const uint i02 = i1 / ne1;
|
||||
|
||||
const uint idst = i1*ne0 + i0/2;
|
||||
const uint ix = rope_a_coord(i0/2, i01, i02, p);
|
||||
const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
|
||||
const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
|
||||
|
||||
const int sect_dims = p.sections[0] + p.sections[1];
|
||||
const int sec_w = p.sections[1] + p.sections[0];
|
||||
@@ -213,11 +186,11 @@ void rope_vision(const uint i0, const uint i1, rope_params p) {
|
||||
float theta_base = 0.0;
|
||||
if (sector < p.sections[0]) {
|
||||
const uint p0 = sector;
|
||||
theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0);
|
||||
theta_base = rope_data_pos[i2]*pow(p.theta_scale, p0);
|
||||
}
|
||||
else if (sector >= p.sections[0] && sector < sec_w) {
|
||||
const uint p0 = sector - p.sections[0];
|
||||
theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0);
|
||||
theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, p0);
|
||||
}
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
|
||||
|
||||
@@ -5,10 +5,13 @@
|
||||
|
||||
// Shader entry point for rope_multi.
//
// NOTE(review): rendered diff with +/- markers lost; the flat-index variant
// (i1 used directly as the row, three-argument rope_multi call) and the
// variant that splits the row into (i1, i2, i3) are interleaved below.
//
// i0 is twice the y invocation id; the flat row index is x + 32768*z.
// Invocations whose row index is >= pc.nrows return without doing any work.
void main() {
|
||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
|
||||
if (i1 >= pc.nrows) {
|
||||
const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
|
||||
if (row >= pc.nrows) {
|
||||
return;
|
||||
}
|
||||
rope_multi(i0, i1, pc);
|
||||
// Split the flat row index: row = i1 + ne01*(i2 + ne02*i3).
|
||||
const uint i3 = row / (pc.ne01*pc.ne02);
|
||||
const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
|
||||
const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
|
||||
|
|
||||
rope_multi(i0, i1, i2, i3, pc);
|
||||
}
|
||||
|
||||
@@ -5,10 +5,13 @@
|
||||
|
||||
// Shader entry point for rope_neox.
//
// NOTE(review): rendered diff with +/- markers lost; the flat-index variant
// (i1 used directly as the row, three-argument rope_neox call) and the
// variant that splits the row into (i1, i2, i3) are interleaved below.
//
// i0 is twice the y invocation id; the flat row index is x + 32768*z.
// Invocations whose row index is >= pc.nrows return without doing any work.
void main() {
|
||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
|
||||
if (i1 >= pc.nrows) {
|
||||
const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
|
||||
if (row >= pc.nrows) {
|
||||
return;
|
||||
}
|
||||
rope_neox(i0, i1, pc);
|
||||
// Split the flat row index: row = i1 + ne01*(i2 + ne02*i3).
|
||||
const uint i3 = row / (pc.ne01*pc.ne02);
|
||||
const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
|
||||
const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
|
||||
|
|
||||
rope_neox(i0, i1, i2, i3, pc);
|
||||
}
|
||||
|
||||
@@ -5,10 +5,13 @@
|
||||
|
||||
// Shader entry point for rope_norm.
//
// NOTE(review): rendered diff with +/- markers lost; the flat-index variant
// (i1 used directly as the row, three-argument rope_norm call) and the
// variant that splits the row into (i1, i2, i3) are interleaved below.
//
// i0 is twice the y invocation id; the flat row index is x + 32768*z.
// Invocations whose row index is >= pc.nrows return without doing any work.
void main() {
|
||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
|
||||
if (i1 >= pc.nrows) {
|
||||
const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
|
||||
if (row >= pc.nrows) {
|
||||
return;
|
||||
}
|
||||
rope_norm(i0, i1, pc);
|
||||
// Split the flat row index: row = i1 + ne01*(i2 + ne02*i3).
|
||||
const uint i3 = row / (pc.ne01*pc.ne02);
|
||||
const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
|
||||
const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
|
||||
|
|
||||
rope_norm(i0, i1, i2, i3, pc);
|
||||
}
|
||||
|
||||
@@ -5,24 +5,29 @@
|
||||
|
||||
// Shader-side parameter block consumed by the rope_* functions.
//
// NOTE(review): rendered diff with +/- markers lost; fields from before and
// after the change are interleaved, so ne02, nb01 and nb02 each appear twice.
// Presumably the post-change layout has to match the host-side
// vk_op_rope_push_constants struct field for field -- TODO confirm.
struct rope_params {
|
||||
uint rope_mode;
|
||||
uint ncols;
|
||||
uint nrows;
|
||||
uint n_dims;
|
||||
float freq_scale;
|
||||
uint p_delta_rows;
|
||||
float freq_base;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float corr_dims[2];
|
||||
float theta_scale;
|
||||
uint has_ff;
|
||||
uint ne02;
|
||||
uint nb01;
|
||||
uint nb02;
|
||||
int sections[4];
|
||||
uint is_imrope;
|
||||
uint is_back;
|
||||
uint set_rows_stride;
|
||||
|
|
||||
// ne00 bounds the column index i0 in the rope_* functions; ne01 and ne02 are
// used by the shader entry points to split the flat row index into (i1, i2, i3).
uint ne00;
|
||||
uint ne01;
|
||||
uint ne02;
|
||||
// nb01/nb02/nb03: input strides used by rope_a_coord.
uint nb01;
|
||||
uint nb02;
|
||||
uint nb03;
|
||||
// nb11/nb12/nb13: output strides used to compute idst in the rope_* functions.
uint nb11;
|
||||
uint nb12;
|
||||
uint nb13;
|
||||
};
|
||||
|
||||
#endif // !defined(GGML_ROPE_PARAMS)
|
||||
|
||||
@@ -5,10 +5,13 @@
|
||||
|
||||
// Shader entry point for rope_vision.
//
// NOTE(review): rendered diff with +/- markers lost; the flat-index variant
// (i1 used directly as the row, three-argument rope_vision call) and the
// variant that splits the row into (i1, i2, i3) are interleaved below.
//
// i0 is twice the y invocation id; the flat row index is x + 32768*z.
// Invocations whose row index is >= pc.nrows return without doing any work.
void main() {
|
||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
|
||||
if (i1 >= pc.nrows) {
|
||||
const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
|
||||
if (row >= pc.nrows) {
|
||||
return;
|
||||
}
|
||||
rope_vision(i0, i1, pc);
|
||||
// Split the flat row index: row = i1 + ne01*(i2 + ne02*i3).
|
||||
const uint i3 = row / (pc.ne01*pc.ne02);
|
||||
const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
|
||||
const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
|
||||
|
|
||||
rope_vision(i0, i1, i2, i3, pc);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user