mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-16 16:27:32 +03:00
Compare commits
29 Commits
b7983
...
gg/llama-h
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6179578988 | ||
|
|
5eb1a88dc0 | ||
|
|
6663128448 | ||
|
|
0bb1da5854 | ||
|
|
165d822044 | ||
|
|
1b74b9d73b | ||
|
|
8c68219835 | ||
|
|
7c6487b22f | ||
|
|
401c13e3c3 | ||
|
|
132143938f | ||
|
|
52b9007176 | ||
|
|
36f8e20d08 | ||
|
|
332f073589 | ||
|
|
39d0b1e8df | ||
|
|
f875d6cb72 | ||
|
|
db2bb378b1 | ||
|
|
79dac3c861 | ||
|
|
1f647b5992 | ||
|
|
eba97574da | ||
|
|
c0cfc2f78b | ||
|
|
828e5d2fcd | ||
|
|
e73690a69d | ||
|
|
e89709721b | ||
|
|
630c84a2bd | ||
|
|
df71c803b4 | ||
|
|
313a444b22 | ||
|
|
695b6b7025 | ||
|
|
f2cd962fe2 | ||
|
|
c1a581a10b |
@@ -55,6 +55,8 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
|
||||
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
|
||||
} else if (type == GGML_TYPE_F32) {
|
||||
v = *(float *) &data[i];
|
||||
} else if (type == GGML_TYPE_I64) {
|
||||
v = (float) *(int64_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I32) {
|
||||
v = (float) *(int32_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I16) {
|
||||
|
||||
@@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
|
||||
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
|
||||
llama_batch batch = llama_batch_init(n_ctx, 0, 1);
|
||||
llama_batch batch = llama_batch_init(n_ctx*n_clients, 0, 1);
|
||||
|
||||
int32_t n_total_prompt = 0;
|
||||
int32_t n_total_gen = 0;
|
||||
@@ -289,6 +289,7 @@ int main(int argc, char ** argv) {
|
||||
// all sequences have ended - clear the entire KV cache
|
||||
for (int i = 1; i <= n_clients; ++i) {
|
||||
llama_memory_seq_rm(mem, i, -1, -1);
|
||||
|
||||
// but keep the system prompt
|
||||
llama_memory_seq_cp(mem, 0, i, -1, -1);
|
||||
}
|
||||
|
||||
@@ -133,6 +133,7 @@ extern "C" {
|
||||
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
|
||||
|
||||
GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
|
||||
GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
|
||||
GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
|
||||
GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
|
||||
|
||||
@@ -470,6 +470,7 @@ extern "C" {
|
||||
GGML_OP_TRANSPOSE,
|
||||
GGML_OP_GET_ROWS,
|
||||
GGML_OP_GET_ROWS_BACK,
|
||||
GGML_OP_SET_ROWS,
|
||||
GGML_OP_DIAG,
|
||||
GGML_OP_DIAG_MASK_INF,
|
||||
GGML_OP_DIAG_MASK_ZERO,
|
||||
@@ -687,6 +688,9 @@ extern "C" {
|
||||
// true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
|
||||
GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
|
||||
|
||||
// true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
|
||||
GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);
|
||||
|
||||
GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
|
||||
GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
|
||||
|
||||
@@ -1375,6 +1379,23 @@ extern "C" {
|
||||
struct ggml_tensor * b, // row indices
|
||||
struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape
|
||||
|
||||
// a TD [n_embd, ne1, ne2, ne3]
|
||||
// b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3
|
||||
// c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1)
|
||||
//
|
||||
// undefined behavior if destination rows overlap
|
||||
//
|
||||
// broadcast:
|
||||
// ne2 % ne11 == 0
|
||||
// ne3 % ne12 == 0
|
||||
//
|
||||
// return view(a)
|
||||
GGML_API struct ggml_tensor * ggml_set_rows(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // destination
|
||||
struct ggml_tensor * b, // source
|
||||
struct ggml_tensor * c); // row indices
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_diag(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
@@ -192,6 +192,7 @@ typedef pthread_t ggml_thread_t;
|
||||
|
||||
static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
||||
[GGML_TYPE_F32] = {
|
||||
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,
|
||||
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
|
||||
.vec_dot_type = GGML_TYPE_F32,
|
||||
.nrows = 1,
|
||||
@@ -1814,6 +1815,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_get_rows_back(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
ggml_compute_forward_set_rows(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_DIAG:
|
||||
{
|
||||
ggml_compute_forward_diag(params, tensor);
|
||||
@@ -2167,6 +2172,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
// FIXME: get_rows can use additional threads, but the cost of launching additional threads
|
||||
// decreases performance with GPU offloading
|
||||
@@ -3121,6 +3127,10 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g
|
||||
return ggml_graph_compute(cgraph, &cplan);
|
||||
}
|
||||
|
||||
void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
|
||||
memcpy(y, x, n * sizeof(float));
|
||||
}
|
||||
|
||||
void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
|
||||
int64_t i = 0;
|
||||
#if defined(__F16C__)
|
||||
|
||||
@@ -416,6 +416,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_SET_ROWS:
|
||||
return
|
||||
op->type != GGML_TYPE_IQ3_XXS &&
|
||||
op->type != GGML_TYPE_IQ3_S &&
|
||||
|
||||
@@ -696,24 +696,8 @@ static void ggml_compute_forward_dup_f32(
|
||||
if (ggml_is_contiguous(dst)) {
|
||||
// TODO: simplify
|
||||
if (nb00 == sizeof(float)) {
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
size_t id = 0;
|
||||
const size_t rs = ne00 * nb00;
|
||||
char * dst_ptr = (char *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
id += rs * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
memcpy(dst_ptr + id, src0_ptr, rs);
|
||||
id += rs;
|
||||
}
|
||||
id += rs * (ne01 - ir1);
|
||||
}
|
||||
}
|
||||
} else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
|
||||
ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
|
||||
if (ggml_get_type_traits_cpu(dst->type)->from_float) {
|
||||
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
|
||||
|
||||
size_t id = 0;
|
||||
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
||||
@@ -724,7 +708,7 @@ static void ggml_compute_forward_dup_f32(
|
||||
id += rs * ir0;
|
||||
for (int i01 = ir0; i01 < ir1; i01++) {
|
||||
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
quantize_row_q(src0_ptr, dst_ptr + id, ne00);
|
||||
from_float(src0_ptr, dst_ptr + id, ne00);
|
||||
id += rs;
|
||||
}
|
||||
id += rs * (ne01 - ir1);
|
||||
@@ -2282,6 +2266,52 @@ static void ggml_compute_forward_repeat_f16(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_repeat_i64(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(ggml_can_repeat(src0, dst));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
// guaranteed to be an integer due to the check in ggml_can_repeat
|
||||
const int nr0 = (int)(ne0/ne00);
|
||||
const int nr1 = (int)(ne1/ne01);
|
||||
const int nr2 = (int)(ne2/ne02);
|
||||
const int nr3 = (int)(ne3/ne03);
|
||||
|
||||
// TODO: support for transposed / permuted tensors
|
||||
GGML_ASSERT(nb0 == sizeof(int64_t));
|
||||
GGML_ASSERT(nb00 == sizeof(int64_t));
|
||||
|
||||
// TODO: maybe this is not optimal?
|
||||
for (int i3 = 0; i3 < nr3; i3++) {
|
||||
for (int k3 = 0; k3 < ne03; k3++) {
|
||||
for (int i2 = 0; i2 < nr2; i2++) {
|
||||
for (int k2 = 0; k2 < ne02; k2++) {
|
||||
for (int i1 = 0; i1 < nr1; i1++) {
|
||||
for (int k1 = 0; k1 < ne01; k1++) {
|
||||
for (int i0 = 0; i0 < nr0; i0++) {
|
||||
int64_t * y = (int64_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
|
||||
int64_t * x = (int64_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
y[i] = x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_repeat(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
@@ -2300,6 +2330,10 @@ void ggml_compute_forward_repeat(
|
||||
{
|
||||
ggml_compute_forward_repeat_f32(params, dst);
|
||||
} break;
|
||||
case GGML_TYPE_I64:
|
||||
{
|
||||
ggml_compute_forward_repeat_i64(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -4470,6 +4504,74 @@ void ggml_compute_forward_get_rows(
|
||||
//}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_set_rows_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int64_t nc = ne00;
|
||||
const int64_t nr = ne01;
|
||||
|
||||
assert(ne0 == nc);
|
||||
assert(ne2 == ne02);
|
||||
assert(ne3 == ne03);
|
||||
assert(src0->type == GGML_TYPE_F32);
|
||||
assert(ne02 % ne11 == 0);
|
||||
assert(ne03 % ne12 == 0);
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
|
||||
|
||||
for (int64_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (int64_t i02 = 0; i02 < ne02; ++i02) {
|
||||
for (int64_t i = ir0; i < ir1; ++i) {
|
||||
const int64_t i12 = i03%ne12;
|
||||
const int64_t i11 = i02%ne11;
|
||||
const int64_t i10 = i;
|
||||
|
||||
const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
|
||||
|
||||
GGML_ASSERT(i1 >= 0 && i1 < ne1);
|
||||
|
||||
from_float(
|
||||
(const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
|
||||
((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_set_rows(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_set_rows_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_get_rows_back
|
||||
|
||||
static void ggml_compute_forward_get_rows_back_f32_f16(
|
||||
@@ -4751,7 +4853,8 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
//const int64_t ne11 = src1 ? src1->ne[1] : 1;
|
||||
const int64_t nb11 = src1 ? src1->nb[1] : 1;
|
||||
const int64_t nb12 = src1 ? src1->nb[2] : 1;
|
||||
|
||||
// TODO: is this supposed to be ceil instead of floor?
|
||||
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
|
||||
@@ -4776,6 +4879,10 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
const int64_t i11 = (i1%ne01);
|
||||
//const int64_t i12 = (i1/ne01)%ne02;
|
||||
const int64_t i13 = (i1/ne01)/ne02;
|
||||
|
||||
// ALiBi
|
||||
const uint32_t h = (i1/ne01)%ne02; // head
|
||||
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
||||
@@ -4784,8 +4891,8 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||
|
||||
// broadcast the mask across rows
|
||||
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
||||
float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
||||
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i13*nb12) : NULL;
|
||||
float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i13*nb12) : NULL;
|
||||
|
||||
ggml_vec_cpy_f32 (nc, wp, sp);
|
||||
ggml_vec_scale_f32(nc, wp, scale);
|
||||
@@ -7125,7 +7232,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
memset(VKQ32, 0, DV*sizeof(float));
|
||||
}
|
||||
|
||||
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
|
||||
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + iq3*mask->nb[2]) : NULL;
|
||||
|
||||
// k indices
|
||||
const int ik3 = iq3 / rk3;
|
||||
|
||||
@@ -53,6 +53,7 @@ void ggml_compute_forward_permute(const struct ggml_compute_params * params, str
|
||||
void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
@@ -230,6 +230,7 @@ typedef struct {
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
uint64_t nb31;
|
||||
uint64_t nb32;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
float scale;
|
||||
@@ -453,6 +454,8 @@ typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
@@ -521,6 +524,22 @@ typedef struct {
|
||||
uint64_t nb2;
|
||||
} ggml_metal_kargs_get_rows;
|
||||
|
||||
typedef struct {
|
||||
int32_t nk0;
|
||||
int32_t ne01;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne11;
|
||||
int32_t ne12;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_set_rows;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
|
||||
@@ -202,6 +202,15 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
||||
GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
|
||||
GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
|
||||
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
|
||||
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
|
||||
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
|
||||
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
|
||||
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
|
||||
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
|
||||
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
||||
@@ -1166,6 +1175,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
||||
@@ -1630,7 +1648,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
|
||||
if (!use_bfloat) {
|
||||
for (size_t i = 0, n = 3; i < n; ++i) {
|
||||
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
|
||||
if (op->src[i] != NULL && (op->src[i]->type == GGML_TYPE_BF16 || op->type == GGML_TYPE_BF16)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1798,6 +1816,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
{
|
||||
return op->ne[3] == 1;
|
||||
}
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
if (op->src[0]->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (op->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
};
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -2523,10 +2562,7 @@ static bool ggml_metal_encode_node(
|
||||
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
|
||||
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
||||
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src0->ne[1];
|
||||
|
||||
const uint32_t n_head = nrows_x/nrows_y;
|
||||
const uint32_t n_head = src0->ne[2];
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
@@ -2586,6 +2622,8 @@ static bool ggml_metal_encode_node(
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.scale =*/ scale,
|
||||
/*.max_bias =*/ max_bias,
|
||||
/*.m0 =*/ m0,
|
||||
@@ -3757,13 +3795,74 @@ static bool ggml_metal_encode_node(
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
|
||||
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
|
||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
|
||||
default: GGML_ABORT("not implemented");
|
||||
}
|
||||
|
||||
const int32_t nk0 = ne0/ggml_blck_size(dst->type);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
int nrptg = 1;
|
||||
if (nth > nk0) {
|
||||
nrptg = (nth + nk0 - 1)/nk0;
|
||||
nth = nk0;
|
||||
|
||||
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||
nrptg--;
|
||||
}
|
||||
}
|
||||
|
||||
nth = MIN(nth, nk0);
|
||||
|
||||
ggml_metal_kargs_set_rows args = {
|
||||
/*.nk0 =*/ nk0,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne11 =*/ ne11,
|
||||
/*.ne12 =*/ ne12,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
|
||||
} break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
{
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
@@ -4782,6 +4881,7 @@ static bool ggml_metal_encode_node(
|
||||
/*.nb22 =*/ nb22,
|
||||
/*.nb23 =*/ nb23,
|
||||
/*.nb31 =*/ nb31,
|
||||
/*.nb32 =*/ nb32,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.scale =*/ scale,
|
||||
|
||||
@@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
|
||||
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
||||
};
|
||||
|
||||
static inline int best_index_int8(int n, constant float * val, float x) {
|
||||
if (x <= val[0]) return 0;
|
||||
if (x >= val[n-1]) return n-1;
|
||||
int ml = 0, mu = n-1;
|
||||
while (mu-ml > 1) {
|
||||
int mav = (ml+mu)/2;
|
||||
if (x < val[mav]) mu = mav; else ml = mav;
|
||||
}
|
||||
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
||||
}
|
||||
|
||||
// NOTE: this is not dequantizing - we are simply fitting the template
|
||||
template <typename type4x4>
|
||||
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
||||
@@ -97,6 +108,173 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK4_0; j++) {
|
||||
const float v = src[j];
|
||||
if (amax < fabs(v)) {
|
||||
amax = fabs(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / -8;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
|
||||
for (int j = 0; j < QK4_0/2; ++j) {
|
||||
const float x0 = src[0 + j]*id;
|
||||
const float x1 = src[QK4_0/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
||||
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
||||
|
||||
dst.qs[j] = xi0;
|
||||
dst.qs[j] |= xi1 << 4;
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
|
||||
float min = FLT_MAX;
|
||||
float max = -FLT_MAX;
|
||||
|
||||
for (int j = 0; j < QK4_1; j++) {
|
||||
const float v = src[j];
|
||||
if (min > v) min = v;
|
||||
if (max < v) max = v;
|
||||
}
|
||||
|
||||
const float d = (max - min) / ((1 << 4) - 1);
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
dst.m = min;
|
||||
|
||||
for (int j = 0; j < QK4_1/2; ++j) {
|
||||
const float x0 = (src[0 + j] - min)*id;
|
||||
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
||||
|
||||
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
||||
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
||||
|
||||
dst.qs[j] = xi0;
|
||||
dst.qs[j] |= xi1 << 4;
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK5_0; j++) {
|
||||
const float v = src[j];
|
||||
if (amax < fabs(v)) {
|
||||
amax = fabs(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / -16;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
|
||||
uint32_t qh = 0;
|
||||
for (int j = 0; j < QK5_0/2; ++j) {
|
||||
const float x0 = src[0 + j]*id;
|
||||
const float x1 = src[QK5_0/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
|
||||
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
|
||||
|
||||
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
||||
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
||||
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
|
||||
}
|
||||
|
||||
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
||||
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
dst.qh[j] = qh8[j];
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
|
||||
float max = src[0];
|
||||
float min = src[0];
|
||||
|
||||
for (int j = 1; j < QK5_1; j++) {
|
||||
const float v = src[j];
|
||||
min = v < min ? v : min;
|
||||
max = v > max ? v : max;
|
||||
}
|
||||
|
||||
const float d = (max - min) / 31;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
dst.m = min;
|
||||
|
||||
uint32_t qh = 0;
|
||||
for (int j = 0; j < QK5_1/2; ++j) {
|
||||
const float x0 = (src[0 + j] - min)*id;
|
||||
const float x1 = (src[QK5_1/2 + j] - min)*id;
|
||||
|
||||
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
|
||||
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
|
||||
|
||||
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
||||
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
||||
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
|
||||
}
|
||||
|
||||
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
||||
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
dst.qh[j] = qh8[j];
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK4_NL; j++) {
|
||||
const float v = src[j];
|
||||
if (amax < fabs(v)) {
|
||||
amax = fabs(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / kvalues_iq4nl_f[0];
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
float sumqx = 0, sumq2 = 0;
|
||||
for (int j = 0; j < QK4_NL/2; ++j) {
|
||||
const float x0 = src[0 + j]*id;
|
||||
const float x1 = src[QK4_NL/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
|
||||
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
|
||||
|
||||
dst.qs[j] = xi0 | (xi1 << 4);
|
||||
|
||||
const float v0 = kvalues_iq4nl_f[xi0];
|
||||
const float v1 = kvalues_iq4nl_f[xi1];
|
||||
const float w0 = src[0 + j]*src[0 + j];
|
||||
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
|
||||
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
|
||||
sumq2 += w0*v0*v0 + w1*v1*v1;
|
||||
|
||||
}
|
||||
|
||||
dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
||||
@@ -279,6 +457,26 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
|
||||
float amax = 0.0f; // absolute max
|
||||
|
||||
for (int j = 0; j < QK8_0; j++) {
|
||||
const float v = src[j];
|
||||
amax = MAX(amax, fabs(v));
|
||||
}
|
||||
|
||||
const float d = amax / ((1 << 7) - 1);
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
|
||||
for (int j = 0; j < QK8_0; ++j) {
|
||||
const float x0 = src[j]*id;
|
||||
|
||||
dst.qs[j] = round(x0);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
||||
const float d = xb->d;
|
||||
@@ -1065,7 +1263,7 @@ kernel void kernel_soft_max(
|
||||
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
|
||||
|
||||
device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr;
|
||||
device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
|
||||
|
||||
float slope = 1.0f;
|
||||
@@ -1161,7 +1359,7 @@ kernel void kernel_soft_max_4(
|
||||
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr;
|
||||
device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
|
||||
|
||||
float slope = 1.0f;
|
||||
@@ -3447,7 +3645,7 @@ kernel void kernel_flash_attn_ext(
|
||||
// load the mask in shared memory
|
||||
#pragma unroll(Q)
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
|
||||
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + iq3*args.nb32);
|
||||
|
||||
const float m = pm[ic + tiisg];
|
||||
|
||||
@@ -3933,7 +4131,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const bool has_mask = mask != q;
|
||||
|
||||
// pointer to the mask
|
||||
device const half * pm = (device const half *) (mask + iq1*args.nb31);
|
||||
device const half * pm = (device const half *) (mask + iq1*args.nb31 + iq3*args.nb32);
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
@@ -4341,6 +4539,7 @@ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bf
|
||||
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
|
||||
#endif
|
||||
|
||||
// TODO: templetify these kernels
|
||||
kernel void kernel_cpy_f32_q8_0(
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
@@ -4364,23 +4563,7 @@ kernel void kernel_cpy_f32_q8_0(
|
||||
for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
|
||||
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
|
||||
float amax = 0.0f; // absolute max
|
||||
|
||||
for (int j = 0; j < QK8_0; j++) {
|
||||
const float v = src[j];
|
||||
amax = MAX(amax, fabs(v));
|
||||
}
|
||||
|
||||
const float d = amax / ((1 << 7) - 1);
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst_data[i00/QK8_0].d = d;
|
||||
|
||||
for (int j = 0; j < QK8_0; ++j) {
|
||||
const float x0 = src[j]*id;
|
||||
|
||||
dst_data[i00/QK8_0].qs[j] = round(x0);
|
||||
}
|
||||
quantize_q8_0(src, dst_data[i00/QK8_0]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4407,32 +4590,7 @@ kernel void kernel_cpy_f32_q4_0(
|
||||
for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
|
||||
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK4_0; j++) {
|
||||
const float v = src[j];
|
||||
if (amax < fabs(v)) {
|
||||
amax = fabs(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / -8;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst_data[i00/QK4_0].d = d;
|
||||
|
||||
for (int j = 0; j < QK4_0/2; ++j) {
|
||||
const float x0 = src[0 + j]*id;
|
||||
const float x1 = src[QK4_0/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
||||
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
||||
|
||||
dst_data[i00/QK4_0].qs[j] = xi0;
|
||||
dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
|
||||
}
|
||||
quantize_q4_0(src, dst_data[i00/QK4_0]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4459,31 +4617,7 @@ kernel void kernel_cpy_f32_q4_1(
|
||||
for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
|
||||
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
|
||||
float min = FLT_MAX;
|
||||
float max = -FLT_MAX;
|
||||
|
||||
for (int j = 0; j < QK4_1; j++) {
|
||||
const float v = src[j];
|
||||
if (min > v) min = v;
|
||||
if (max < v) max = v;
|
||||
}
|
||||
|
||||
const float d = (max - min) / ((1 << 4) - 1);
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst_data[i00/QK4_1].d = d;
|
||||
dst_data[i00/QK4_1].m = min;
|
||||
|
||||
for (int j = 0; j < QK4_1/2; ++j) {
|
||||
const float x0 = (src[0 + j] - min)*id;
|
||||
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
||||
|
||||
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
||||
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
||||
|
||||
dst_data[i00/QK4_1].qs[j] = xi0;
|
||||
dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
|
||||
}
|
||||
quantize_q4_1(src, dst_data[i00/QK4_1]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4510,38 +4644,7 @@ kernel void kernel_cpy_f32_q5_0(
|
||||
for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
|
||||
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK5_0; j++) {
|
||||
const float v = src[j];
|
||||
if (amax < fabs(v)) {
|
||||
amax = fabs(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / -16;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst_data[i00/QK5_0].d = d;
|
||||
|
||||
uint32_t qh = 0;
|
||||
for (int j = 0; j < QK5_0/2; ++j) {
|
||||
const float x0 = src[0 + j]*id;
|
||||
const float x1 = src[QK5_0/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
|
||||
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
|
||||
|
||||
dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
||||
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
||||
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
|
||||
}
|
||||
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
dst_data[i00/QK5_0].qh[j] = qh8[j];
|
||||
}
|
||||
quantize_q5_0(src, dst_data[i00/QK5_0]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4568,51 +4671,10 @@ kernel void kernel_cpy_f32_q5_1(
|
||||
for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
|
||||
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
|
||||
float max = src[0];
|
||||
float min = src[0];
|
||||
|
||||
for (int j = 1; j < QK5_1; j++) {
|
||||
const float v = src[j];
|
||||
min = v < min ? v : min;
|
||||
max = v > max ? v : max;
|
||||
}
|
||||
|
||||
const float d = (max - min) / 31;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst_data[i00/QK5_1].d = d;
|
||||
dst_data[i00/QK5_1].m = min;
|
||||
|
||||
uint32_t qh = 0;
|
||||
for (int j = 0; j < QK5_1/2; ++j) {
|
||||
const float x0 = (src[0 + j] - min)*id;
|
||||
const float x1 = (src[QK5_1/2 + j] - min)*id;
|
||||
|
||||
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
|
||||
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
|
||||
|
||||
dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
||||
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
||||
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
|
||||
}
|
||||
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
dst_data[i00/QK5_1].qh[j] = qh8[j];
|
||||
}
|
||||
quantize_q5_1(src, dst_data[i00/QK5_1]);
|
||||
}
|
||||
}
|
||||
|
||||
static inline int best_index_int8(int n, constant float * val, float x) {
|
||||
if (x <= val[0]) return 0;
|
||||
if (x >= val[n-1]) return n-1;
|
||||
int ml = 0, mu = n-1;
|
||||
while (mu-ml > 1) {
|
||||
int mav = (ml+mu)/2;
|
||||
if (x < val[mav]) mu = mav; else ml = mav;
|
||||
}
|
||||
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f32_iq4_nl(
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
@@ -4636,40 +4698,7 @@ kernel void kernel_cpy_f32_iq4_nl(
|
||||
for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
|
||||
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK4_NL; j++) {
|
||||
const float v = src[j];
|
||||
if (amax < fabs(v)) {
|
||||
amax = fabs(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / kvalues_iq4nl_f[0];
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
float sumqx = 0, sumq2 = 0;
|
||||
for (int j = 0; j < QK4_NL/2; ++j) {
|
||||
const float x0 = src[0 + j]*id;
|
||||
const float x1 = src[QK4_NL/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
|
||||
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
|
||||
|
||||
dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
|
||||
|
||||
const float v0 = kvalues_iq4nl_f[xi0];
|
||||
const float v1 = kvalues_iq4nl_f[xi1];
|
||||
const float w0 = src[0 + j]*src[0 + j];
|
||||
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
|
||||
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
|
||||
sumq2 += w0*v0*v0 + w1*v1*v1;
|
||||
|
||||
}
|
||||
|
||||
dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
|
||||
quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6350,10 +6379,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
||||
kernel void kernel_get_rows_q(
|
||||
constant ggml_metal_kargs_get_rows & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
constant ggml_metal_kargs_get_rows & args,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg [[threads_per_threadgroup]]) {
|
||||
@@ -6373,10 +6402,10 @@ kernel void kernel_get_rows_q(
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_get_rows_f(
|
||||
constant ggml_metal_kargs_get_rows & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
constant ggml_metal_kargs_get_rows & args,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg [[threads_per_threadgroup]]) {
|
||||
@@ -6394,10 +6423,10 @@ kernel void kernel_get_rows_f(
|
||||
}
|
||||
|
||||
kernel void kernel_get_rows_i32(
|
||||
constant ggml_metal_kargs_get_rows & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device int32_t * dst,
|
||||
constant ggml_metal_kargs_get_rows & args,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg [[threads_per_threadgroup]]) {
|
||||
@@ -6414,6 +6443,67 @@ kernel void kernel_get_rows_i32(
|
||||
}
|
||||
}
|
||||
|
||||
template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
|
||||
kernel void kernel_set_rows_q32(
|
||||
constant ggml_metal_kargs_set_rows & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg [[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
|
||||
const int32_t i12 = i03%args.ne12;
|
||||
const int32_t i11 = i02%args.ne11;
|
||||
|
||||
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t i10 = i01;
|
||||
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
||||
|
||||
device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
||||
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
||||
|
||||
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
|
||||
quantize_func(src_row + 32*ind, dst_row[ind]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_set_rows_f(
|
||||
constant ggml_metal_kargs_set_rows & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg [[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
|
||||
const int32_t i12 = i03%args.ne12;
|
||||
const int32_t i11 = i02%args.ne11;
|
||||
|
||||
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t i10 = i01;
|
||||
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
||||
|
||||
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
||||
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
||||
|
||||
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
|
||||
dst_row[ind] = (T) src_row[ind];
|
||||
}
|
||||
}
|
||||
|
||||
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
||||
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
||||
@@ -6837,6 +6927,27 @@ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get
|
||||
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
|
||||
//
|
||||
// set rows
|
||||
//
|
||||
|
||||
typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
|
||||
|
||||
template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
|
||||
template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
|
||||
#endif
|
||||
|
||||
typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
|
||||
|
||||
template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0, quantize_q8_0>;
|
||||
template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0, quantize_q4_0>;
|
||||
template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1, quantize_q4_1>;
|
||||
template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0, quantize_q5_0>;
|
||||
template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1, quantize_q5_1>;
|
||||
template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
|
||||
|
||||
//
|
||||
// matrix-matrix multiplication
|
||||
//
|
||||
|
||||
@@ -936,6 +936,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"TRANSPOSE",
|
||||
"GET_ROWS",
|
||||
"GET_ROWS_BACK",
|
||||
"SET_ROWS",
|
||||
"DIAG",
|
||||
"DIAG_MASK_INF",
|
||||
"DIAG_MASK_ZERO",
|
||||
@@ -986,7 +987,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"OPT_STEP_ADAMW",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@@ -1032,6 +1033,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"transpose(x)",
|
||||
"get_rows(x)",
|
||||
"get_rows_back(x)",
|
||||
"set_rows(x)",
|
||||
"diag(x)",
|
||||
"diag_mask_inf(x)",
|
||||
"diag_mask_zero(x)",
|
||||
@@ -1082,7 +1084,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"adamw(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -1351,6 +1353,12 @@ bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
|
||||
tensor->nb[2] == ggml_type_size(tensor->type);
|
||||
}
|
||||
|
||||
bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
|
||||
return
|
||||
tensor->ne[0] == ggml_blck_size(tensor->type) ||
|
||||
tensor->nb[0] == ggml_type_size(tensor->type);
|
||||
}
|
||||
|
||||
static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
@@ -3395,6 +3403,35 @@ struct ggml_tensor * ggml_get_rows_back(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_set_rows
|
||||
|
||||
struct ggml_tensor * ggml_set_rows(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c) {
|
||||
GGML_ASSERT(a->ne[0] == b->ne[0]);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[2]);
|
||||
GGML_ASSERT(a->ne[3] == b->ne[3]);
|
||||
GGML_ASSERT(b->ne[1] == c->ne[0]);
|
||||
GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
|
||||
GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
|
||||
GGML_ASSERT(c->ne[3] == 1);
|
||||
GGML_ASSERT(b->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(c->type == GGML_TYPE_I64);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(a));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(b));
|
||||
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
||||
|
||||
result->op = GGML_OP_SET_ROWS;
|
||||
result->src[0] = b;
|
||||
result->src[1] = c;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_diag
|
||||
|
||||
struct ggml_tensor * ggml_diag(
|
||||
@@ -3489,7 +3526,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
||||
if (mask) {
|
||||
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(mask));
|
||||
GGML_ASSERT(ggml_is_matrix(mask));
|
||||
GGML_ASSERT(ggml_is_3d(mask));
|
||||
GGML_ASSERT(mask->ne[0] == a->ne[0]);
|
||||
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
|
||||
}
|
||||
@@ -4467,7 +4504,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
||||
|
||||
if (mask) {
|
||||
GGML_ASSERT(ggml_is_contiguous(mask));
|
||||
GGML_ASSERT(mask->ne[2] == 1);
|
||||
GGML_ASSERT(mask->ne[2] == q->ne[3]);
|
||||
GGML_ASSERT(mask->ne[3] == 1);
|
||||
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
|
||||
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
|
||||
|
||||
@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
|
||||
|
||||
// note: tracking the other way around is not necessary for now
|
||||
//seq_cpl[s0][s1] = true;
|
||||
|
||||
has_cpl = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -403,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
|
||||
return n_outputs;
|
||||
}
|
||||
|
||||
uint32_t llama_batch_allocr::get_n_used() const {
|
||||
return n_used;
|
||||
}
|
||||
|
||||
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
|
||||
return out_ids;
|
||||
}
|
||||
@@ -418,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
|
||||
void llama_batch_allocr::split_reset() {
|
||||
out_ids.clear();
|
||||
|
||||
n_used = 0;
|
||||
|
||||
used.clear();
|
||||
used.resize(get_n_tokens(), false);
|
||||
|
||||
@@ -442,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
|
||||
idxs.push_back(cur_idx);
|
||||
|
||||
used[cur_idx] = true;
|
||||
++n_used;
|
||||
|
||||
++cur_idx;
|
||||
|
||||
@@ -457,9 +466,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
|
||||
return ubatch_add(idxs, idxs.size(), false);
|
||||
}
|
||||
|
||||
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
||||
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
|
||||
if (sequential && has_cpl) {
|
||||
LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<seq_set_t> cur_seq_set;
|
||||
|
||||
llama_seq_id last_seq_id = -1;
|
||||
|
||||
// determine the non-overlapping sequence sets participating in this ubatch
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
if (used[i]) {
|
||||
@@ -476,9 +493,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
// accept only increasing sequence ids
|
||||
if (sequential) {
|
||||
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
|
||||
}
|
||||
|
||||
if (add) {
|
||||
cur_seq_set.push_back(seq_set[i]);
|
||||
|
||||
last_seq_id = batch.seq_id[i][0];
|
||||
|
||||
if (cur_seq_set.size() > n_ubatch) {
|
||||
break;
|
||||
}
|
||||
@@ -527,6 +551,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
||||
idxs_per_seq[s].push_back(idx);
|
||||
|
||||
used[idx] = true;
|
||||
++n_used;
|
||||
|
||||
++cur_idx[s];
|
||||
}
|
||||
@@ -568,6 +593,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
|
||||
idxs.push_back(cur_idx);
|
||||
|
||||
used[cur_idx] = true;
|
||||
++n_used;
|
||||
|
||||
if (idxs.size() >= n_ubatch) {
|
||||
break;
|
||||
|
||||
@@ -54,6 +54,7 @@ public:
|
||||
|
||||
uint32_t get_n_tokens() const;
|
||||
uint32_t get_n_outputs() const;
|
||||
uint32_t get_n_used() const;
|
||||
|
||||
// the array of output indices in the order they were encountered during the ubatch splitting
|
||||
std::vector<int32_t> & get_out_ids();
|
||||
@@ -69,7 +70,8 @@ public:
|
||||
llama_ubatch split_simple(uint32_t n_ubatch);
|
||||
|
||||
// make ubatches of equal-length sequences sets
|
||||
llama_ubatch split_equal(uint32_t n_ubatch);
|
||||
// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
|
||||
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
|
||||
|
||||
// sequence-set-wise split - each ubatch contains a single sequence-set
|
||||
llama_ubatch split_seq(uint32_t n_ubatch);
|
||||
@@ -112,6 +114,9 @@ private:
|
||||
using pos_set_t = std::set<llama_pos>;
|
||||
using seq_cpl_t = std::vector<bool>;
|
||||
|
||||
// helper flag to quickly determine if there are any coupled sequences in the batch
|
||||
bool has_cpl;
|
||||
|
||||
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
|
||||
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
|
||||
|
||||
@@ -125,6 +130,8 @@ private:
|
||||
// batch indices of the output
|
||||
std::vector<int32_t> out_ids;
|
||||
|
||||
uint32_t n_used;
|
||||
|
||||
// used[i] indicates if token i has already been used in a previous ubatch
|
||||
std::vector<bool> used;
|
||||
|
||||
|
||||
@@ -33,6 +33,9 @@ llama_context::llama_context(
|
||||
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
|
||||
}
|
||||
|
||||
const char * LLAMA_HT = getenv("LLAMA_HT");
|
||||
cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1;
|
||||
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||
|
||||
@@ -11,8 +11,9 @@ struct llama_cparams {
|
||||
uint32_t n_batch;
|
||||
uint32_t n_ubatch;
|
||||
uint32_t n_seq_max;
|
||||
int n_threads; // number of threads to use for generation
|
||||
int n_threads_batch; // number of threads to use for batch processing
|
||||
uint32_t n_seq_virt;
|
||||
int32_t n_threads; // number of threads to use for generation
|
||||
int32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
|
||||
@@ -281,12 +281,36 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_k_idxs) {
|
||||
mctx->set_input_k_idxs(self_k_idxs, ubatch);
|
||||
}
|
||||
|
||||
if (self_v_idxs) {
|
||||
mctx->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
}
|
||||
|
||||
if (self_kq_mask) {
|
||||
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_k_idxs) {
|
||||
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
|
||||
}
|
||||
|
||||
if (self_v_idxs) {
|
||||
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
}
|
||||
|
||||
if (self_k_idxs_swa) {
|
||||
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
|
||||
}
|
||||
|
||||
if (self_v_idxs_swa) {
|
||||
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
|
||||
}
|
||||
|
||||
if (self_kq_mask) {
|
||||
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
@@ -989,10 +1013,10 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
|
||||
|
||||
const auto n_kv = inp->mctx->get_attn()->get_n_kv();
|
||||
const auto n_kv = inp->mctx->get_attn()->get_n_kv();
|
||||
const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
@@ -1019,6 +1043,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
float kq_scale) const {
|
||||
const bool v_trans = v->nb[1] > v->nb[2];
|
||||
|
||||
const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
|
||||
|
||||
q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
|
||||
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
||||
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
||||
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
||||
@@ -1067,7 +1095,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
#endif
|
||||
}
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
|
||||
} else {
|
||||
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
|
||||
@@ -1112,7 +1140,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
|
||||
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
|
||||
|
||||
if (!cparams.offload_kqv) {
|
||||
// all nodes between the KV store and the attention output are run on the CPU
|
||||
@@ -1190,10 +1218,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
@@ -1224,8 +1255,11 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
||||
const auto & k_idxs = inp->get_k_idxs();
|
||||
const auto & v_idxs = inp->get_v_idxs();
|
||||
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
|
||||
}
|
||||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
@@ -1278,8 +1312,11 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
||||
const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
|
||||
const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
|
||||
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
|
||||
}
|
||||
|
||||
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
||||
@@ -1383,8 +1420,8 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, nullptr, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, nullptr, il));
|
||||
}
|
||||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
@@ -1416,11 +1453,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
|
||||
|
||||
const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
|
||||
|
||||
{
|
||||
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
@@ -1431,8 +1472,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
||||
|
||||
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
||||
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
|
||||
@@ -248,10 +248,16 @@ public:
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
|
||||
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
@@ -273,13 +279,23 @@ public:
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
|
||||
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
|
||||
ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
|
||||
ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
|
||||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
||||
@@ -20,14 +20,15 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
||||
bool swa_full,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_seq_virt,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad) : hparams(model.hparams) {
|
||||
uint32_t n_pad) : hparams(model.hparams), n_seq_virt(n_seq_virt) {
|
||||
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
||||
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
||||
|
||||
const uint32_t size_base = kv_size;
|
||||
|
||||
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
|
||||
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));
|
||||
|
||||
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
|
||||
if (swa_full) {
|
||||
@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
||||
|
||||
kv_base = std::make_unique<llama_kv_cache_unified>(
|
||||
model, std::move(filter_base), type_k, type_v,
|
||||
v_trans, offload, size_base, n_seq_max, n_pad,
|
||||
v_trans, offload, size_base, n_seq_max, n_seq_virt, n_pad,
|
||||
0, LLAMA_SWA_TYPE_NONE);
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
||||
|
||||
kv_swa = std::make_unique<llama_kv_cache_unified>(
|
||||
model, std::move(filter_swa), type_k, type_v,
|
||||
v_trans, offload, size_swa, n_seq_max, n_pad,
|
||||
v_trans, offload, size_swa, n_seq_max, n_seq_virt, n_pad,
|
||||
hparams.n_swa, hparams.swa_type);
|
||||
}
|
||||
|
||||
@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
||||
|
||||
// first try simple split
|
||||
do {
|
||||
if (n_seq_virt > 1) {
|
||||
// requires equal splits, so we skip the simple split
|
||||
break;
|
||||
}
|
||||
|
||||
balloc.split_reset();
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
@@ -113,20 +119,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
auto heads_base = kv_base->prepare(ubatches);
|
||||
if (heads_base.empty()) {
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
auto heads_swa = kv_swa->prepare(ubatches);
|
||||
if (heads_swa.empty()) {
|
||||
auto sinfos_base = kv_base->prepare(ubatches);
|
||||
if (sinfos_base.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
assert(heads_base.size() == heads_swa.size());
|
||||
auto sinfos_swa = kv_swa->prepare(ubatches);
|
||||
if (sinfos_swa.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
assert(sinfos_base.size() == sinfos_swa.size());
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(
|
||||
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
// if it fails, try equal split
|
||||
@@ -135,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (true) {
|
||||
auto ubatch = balloc.split_equal(n_ubatch);
|
||||
auto ubatch = balloc.split_equal(n_ubatch, n_seq_virt > 1);
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
@@ -144,20 +155,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
auto heads_base = kv_base->prepare(ubatches);
|
||||
if (heads_base.empty()) {
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
auto heads_swa = kv_swa->prepare(ubatches);
|
||||
if (heads_swa.empty()) {
|
||||
auto sinfos_base = kv_base->prepare(ubatches);
|
||||
if (sinfos_base.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
assert(heads_base.size() == heads_swa.size());
|
||||
auto sinfos_swa = kv_swa->prepare(ubatches);
|
||||
if (sinfos_swa.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
assert(sinfos_base.size() == sinfos_swa.size());
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(
|
||||
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
// TODO: if we fail again, we should attempt different splitting strategies
|
||||
@@ -220,13 +236,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||
|
||||
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
std::vector<uint32_t> heads_base,
|
||||
std::vector<uint32_t> heads_swa,
|
||||
slot_info_vec_t sinfos_base,
|
||||
slot_info_vec_t sinfos_swa,
|
||||
std::vector<llama_ubatch> ubatches) :
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
|
||||
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
|
||||
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
|
||||
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
|
||||
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ public:
|
||||
bool swa_full,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_seq_virt,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad);
|
||||
|
||||
@@ -68,12 +69,16 @@ public:
|
||||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
const uint32_t n_seq_virt = 1;
|
||||
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_base;
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
||||
};
|
||||
|
||||
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
|
||||
public:
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
|
||||
// used for errors
|
||||
llama_kv_cache_unified_iswa_context(llama_memory_status status);
|
||||
|
||||
@@ -90,8 +95,8 @@ public:
|
||||
// used to create a batch processing context from a batch
|
||||
llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
std::vector<uint32_t> heads_base,
|
||||
std::vector<uint32_t> heads_swa,
|
||||
slot_info_vec_t sinfos_base,
|
||||
slot_info_vec_t sinfos_swa,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_unified_iswa_context();
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,8 +24,6 @@ public:
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
using ubatch_heads = std::vector<uint32_t>;
|
||||
|
||||
struct defrag_info {
|
||||
bool empty() const {
|
||||
return ids.empty();
|
||||
@@ -37,6 +35,48 @@ public:
|
||||
std::vector<uint32_t> ids;
|
||||
};
|
||||
|
||||
struct slot_info {
|
||||
// data for ggml_set_rows
|
||||
using idx_vec_t = std::vector<uint32_t>;
|
||||
|
||||
llama_seq_id s0;
|
||||
llama_seq_id s1;
|
||||
|
||||
std::vector<llama_seq_id> seq_id_virt;
|
||||
std::vector<idx_vec_t> idxs;
|
||||
|
||||
uint32_t head() const {
|
||||
GGML_ASSERT(idxs.size() == 1);
|
||||
|
||||
return idxs[0][0];
|
||||
}
|
||||
|
||||
void resize(size_t n) {
|
||||
seq_id_virt.resize(n);
|
||||
idxs.resize(n);
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
GGML_ASSERT(idxs.size() == seq_id_virt.size());
|
||||
|
||||
return idxs[0].size();
|
||||
}
|
||||
|
||||
size_t n_seq_virt() const {
|
||||
return seq_id_virt.size();
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return idxs.empty();
|
||||
}
|
||||
|
||||
void clear() {
|
||||
idxs.clear();
|
||||
}
|
||||
};
|
||||
|
||||
using slot_info_vec_t = std::vector<slot_info>;
|
||||
|
||||
llama_kv_cache_unified(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
@@ -46,6 +86,7 @@ public:
|
||||
bool offload,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_seq_virt,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type);
|
||||
@@ -98,36 +139,44 @@ public:
|
||||
uint32_t get_n_kv() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
|
||||
|
||||
//
|
||||
// preparation API
|
||||
//
|
||||
|
||||
// find places for the provided ubatches in the cache, returns the head locations
|
||||
// find places for the provided ubatches in the cache, returns the slot infos
|
||||
// return empty vector on failure
|
||||
ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
||||
|
||||
// return the cell position where we can insert the ubatch
|
||||
// return -1 on failure to find a contiguous slot of kv cells
|
||||
int32_t find_slot(const llama_ubatch & ubatch) const;
|
||||
// find a slot of kv cells that can hold the ubatch
|
||||
// if cont == true, then the slot must be continuous
|
||||
// return empty slot_info on failure
|
||||
slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
|
||||
|
||||
// emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
|
||||
void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
|
||||
// emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
|
||||
void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
|
||||
|
||||
//
|
||||
// set_input API
|
||||
// input API
|
||||
//
|
||||
|
||||
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
|
||||
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||
|
||||
void set_input_k_shift(ggml_tensor * dst) const;
|
||||
|
||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_k_shift (ggml_tensor * dst) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
private:
|
||||
@@ -141,15 +190,15 @@ private:
|
||||
|
||||
ggml_tensor * k;
|
||||
ggml_tensor * v;
|
||||
|
||||
std::vector<ggml_tensor *> k_seq;
|
||||
std::vector<ggml_tensor *> v_seq;
|
||||
};
|
||||
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
uint32_t head = 0;
|
||||
|
||||
const uint32_t n_seq_max = 1;
|
||||
const uint32_t n_seq_max = 1;
|
||||
const uint32_t n_seq_virt = 1;
|
||||
|
||||
// required padding
|
||||
const uint32_t n_pad = 1;
|
||||
@@ -157,14 +206,26 @@ private:
|
||||
// SWA
|
||||
const uint32_t n_swa = 0;
|
||||
|
||||
// env: LLAMA_KV_CACHE_DEBUG
|
||||
int debug = 0;
|
||||
|
||||
// env: LLAMA_SET_ROWS (temporary)
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
|
||||
int supports_set_rows = false;
|
||||
|
||||
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
llama_kv_cells_unified cells;
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
std::vector<uint32_t> v_heads;
|
||||
|
||||
std::vector<llama_kv_cells_unified> v_cells;
|
||||
|
||||
// maps from a sequence id to a virtual sequence id
|
||||
std::vector<uint32_t> seq_virt_idx;
|
||||
|
||||
std::vector<kv_layer> layers;
|
||||
|
||||
@@ -211,8 +272,8 @@ private:
|
||||
class llama_kv_cache_unified_context : public llama_memory_context_i {
|
||||
public:
|
||||
// some shorthands
|
||||
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
|
||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||
|
||||
// used for errors
|
||||
llama_kv_cache_unified_context(llama_memory_status status);
|
||||
@@ -231,7 +292,7 @@ public:
|
||||
// used to create a batch procesing context from a batch
|
||||
llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv,
|
||||
ubatch_heads heads,
|
||||
slot_info_vec_t sinfos,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_unified_context();
|
||||
@@ -257,8 +318,14 @@ public:
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
|
||||
|
||||
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
|
||||
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
void set_input_k_shift(ggml_tensor * dst) const;
|
||||
|
||||
@@ -283,10 +350,10 @@ private:
|
||||
// batch processing context
|
||||
//
|
||||
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
// the index of the cur ubatch to process
|
||||
size_t i_cur = 0;
|
||||
|
||||
ubatch_heads heads;
|
||||
slot_info_vec_t sinfos;
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
@@ -297,7 +364,4 @@ private:
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// as the cache gets filled, the benefit from this heuristic disappears
|
||||
int32_t n_kv;
|
||||
|
||||
// the beginning of the current slot in which the ubatch will be inserted
|
||||
int32_t head;
|
||||
};
|
||||
|
||||
@@ -105,10 +105,29 @@ public:
|
||||
res.resize(n);
|
||||
|
||||
for (uint32_t j = 0; j < n; ++j) {
|
||||
res.pos[j] = pos[i + j];
|
||||
res.seq[j] = seq[i + j];
|
||||
const auto idx = i + j;
|
||||
|
||||
assert(shift[i + j] == 0);
|
||||
res.pos[j] = pos[idx];
|
||||
res.seq[j] = seq[idx];
|
||||
|
||||
assert(shift[idx] == 0);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
|
||||
llama_kv_cells_unified res;
|
||||
|
||||
res.resize(idxs.size());
|
||||
|
||||
for (uint32_t j = 0; j < idxs.size(); ++j) {
|
||||
const auto idx = idxs[j];
|
||||
|
||||
res.pos[j] = pos[idx];
|
||||
res.seq[j] = seq[idx];
|
||||
|
||||
assert(shift[idx] == 0);
|
||||
}
|
||||
|
||||
return res;
|
||||
@@ -119,26 +138,57 @@ public:
|
||||
assert(i + other.pos.size() <= pos.size());
|
||||
|
||||
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
||||
if (pos[i + j] == -1 && other.pos[j] != -1) {
|
||||
const auto idx = i + j;
|
||||
|
||||
if (pos[idx] == -1 && other.pos[j] != -1) {
|
||||
used.insert(i + j);
|
||||
}
|
||||
|
||||
if (pos[i + j] != -1 && other.pos[j] == -1) {
|
||||
if (pos[idx] != -1 && other.pos[j] == -1) {
|
||||
used.erase(i + j);
|
||||
}
|
||||
|
||||
if (pos[i + j] != -1) {
|
||||
if (pos[idx] != -1) {
|
||||
seq_pos_rm(i + j);
|
||||
}
|
||||
|
||||
pos[i + j] = other.pos[j];
|
||||
seq[i + j] = other.seq[j];
|
||||
pos[idx] = other.pos[j];
|
||||
seq[idx] = other.seq[j];
|
||||
|
||||
if (pos[i + j] != -1) {
|
||||
if (pos[idx] != -1) {
|
||||
seq_pos_add(i + j);
|
||||
}
|
||||
|
||||
assert(shift[i + j] == 0);
|
||||
assert(shift[idx] == 0);
|
||||
}
|
||||
}
|
||||
|
||||
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
|
||||
assert(idxs.size() == other.pos.size());
|
||||
|
||||
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
||||
const auto idx = idxs[j];
|
||||
|
||||
if (pos[idx] == -1 && other.pos[j] != -1) {
|
||||
used.insert(idx);
|
||||
}
|
||||
|
||||
if (pos[idx] != -1 && other.pos[j] == -1) {
|
||||
used.erase(idx);
|
||||
}
|
||||
|
||||
if (pos[idx] != -1) {
|
||||
seq_pos_rm(idx);
|
||||
}
|
||||
|
||||
pos[idx] = other.pos[j];
|
||||
seq[idx] = other.seq[j];
|
||||
|
||||
if (pos[idx] != -1) {
|
||||
seq_pos_add(idx);
|
||||
}
|
||||
|
||||
assert(shift[idx] == 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
|
||||
offload,
|
||||
kv_size,
|
||||
n_seq_max,
|
||||
1,
|
||||
n_pad,
|
||||
n_swa,
|
||||
swa_type
|
||||
@@ -70,7 +71,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
|
||||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
ubatch = balloc.split_equal(n_ubatch);
|
||||
ubatch = balloc.split_equal(n_ubatch, false);
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
@@ -80,6 +81,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
// prepare the recurrent batches first
|
||||
if (!mem_recr->prepare(ubatches)) {
|
||||
// TODO: will the recurrent cache be in an undefined context at this point?
|
||||
@@ -195,11 +201,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
|
||||
|
||||
llama_memory_hybrid_context::llama_memory_hybrid_context(
|
||||
llama_memory_hybrid * mem,
|
||||
std::vector<uint32_t> heads_attn,
|
||||
slot_info_vec_t sinfos_attn,
|
||||
std::vector<llama_ubatch> ubatches) :
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
|
||||
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
|
||||
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
@@ -92,6 +92,8 @@ private:
|
||||
|
||||
class llama_memory_hybrid_context : public llama_memory_context_i {
|
||||
public:
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
|
||||
// init failure
|
||||
explicit llama_memory_hybrid_context(llama_memory_status status);
|
||||
|
||||
@@ -107,7 +109,7 @@ public:
|
||||
// init success
|
||||
llama_memory_hybrid_context(
|
||||
llama_memory_hybrid * mem,
|
||||
std::vector<uint32_t> heads_attn,
|
||||
slot_info_vec_t sinfos_attn,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
~llama_memory_hybrid_context() = default;
|
||||
|
||||
@@ -365,26 +365,35 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
||||
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
while (true) {
|
||||
llama_ubatch ubatch;
|
||||
do {
|
||||
balloc.split_reset();
|
||||
|
||||
if (embd_all) {
|
||||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
ubatch = balloc.split_equal(n_ubatch);
|
||||
while (true) {
|
||||
llama_ubatch ubatch;
|
||||
|
||||
if (embd_all) {
|
||||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
ubatch = balloc.split_equal(n_ubatch, false);
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
if (!prepare(ubatches)) {
|
||||
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
if (!prepare(ubatches)) {
|
||||
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
} while (false);
|
||||
|
||||
return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
|
||||
}
|
||||
|
||||
@@ -13814,6 +13814,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
params.swa_full,
|
||||
cparams.n_ctx,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_seq_virt,
|
||||
cparams.n_ubatch,
|
||||
padding);
|
||||
} else {
|
||||
@@ -13828,6 +13829,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
cparams.offload_kqv,
|
||||
cparams.n_ctx,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_seq_virt,
|
||||
padding,
|
||||
hparams.n_swa,
|
||||
hparams.swa_type);
|
||||
|
||||
@@ -1213,6 +1213,78 @@ struct test_get_rows_back : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_SET_ROWS
|
||||
struct test_set_rows : public test_case {
|
||||
const ggml_type type;
|
||||
const int n; // cols
|
||||
const int m; // rows
|
||||
const int r; // rows to set
|
||||
const int b0; // batch size
|
||||
const int b1; // batch size
|
||||
const int bs; // batch size src (for testing broadcast)
|
||||
const bool v; // view (non-contiguous src1)
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR7(type, n, m, r, b0, bs, v);
|
||||
}
|
||||
|
||||
test_set_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, int bs = 1, bool v = false)
|
||||
: type(type), n(n), m(m), r(r), b0(b), b1(3), bs(bs), v(v) {
|
||||
GGML_ASSERT(b0 % bs == 0 && "b0 must be a multiple of bs");
|
||||
GGML_ASSERT(r <= m && "r must be less than or equal to m");
|
||||
}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, n, m, b0, b1);
|
||||
ggml_set_name(dst, "dst");
|
||||
|
||||
ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n, r, b0, b1);
|
||||
ggml_set_name(src, "src");
|
||||
|
||||
ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, r, bs, b1);
|
||||
ggml_set_name(row_idxs, "row_idxs");
|
||||
|
||||
if (v) {
|
||||
src = ggml_view_4d(ctx, src, n, r/2, b0, b1, src->nb[1], src->nb[2], src->nb[3], 0);
|
||||
row_idxs = ggml_view_3d(ctx, row_idxs, r/2, bs, b1, row_idxs->nb[1], row_idxs->nb[2], 0);
|
||||
ggml_set_name(row_idxs, "view_of_rows");
|
||||
}
|
||||
|
||||
ggml_tensor * out = ggml_set_rows(ctx, dst, src, row_idxs);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
std::random_device rd;
|
||||
std::default_random_engine rng(rd());
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I64) {
|
||||
if (ggml_is_view_op(t->op)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int i2 = 0; i2 < t->ne[2]; i2++) {
|
||||
for (int i1 = 0; i1 < t->ne[1]; i1++) {
|
||||
std::vector<int64_t> data(m);
|
||||
for (int i = 0; i < m; i++) {
|
||||
data[i] = i;
|
||||
}
|
||||
std::shuffle(data.begin(), data.end(), rng);
|
||||
data.resize(t->ne[0]);
|
||||
|
||||
const size_t offs = i1*t->nb[1] + i2*t->nb[2];
|
||||
ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_ARGMAX
|
||||
struct test_argmax : public test_case {
|
||||
const ggml_type type;
|
||||
@@ -3984,6 +4056,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, 1, 8, 2, 1, 1, false));
|
||||
for (ggml_type type : all_types) {
|
||||
for (int b : {1, 7}) {
|
||||
for (bool v : {false, true}) {
|
||||
test_cases.emplace_back(new test_set_rows(type, 256, 5, 4, b, 1, v));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (ggml_type type_input : {GGML_TYPE_F32}) {
|
||||
for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
|
||||
for (int k0 : {1, 3}) {
|
||||
|
||||
@@ -61,7 +61,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const int32_t n_kv_max = llama_n_ctx(ctx);
|
||||
|
||||
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
|
||||
llama_batch batch = llama_batch_init(n_kv_max*8, 0, 1); // TODO: tmp!!!
|
||||
|
||||
// decode in batches of ctx_params.n_batch tokens
|
||||
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
|
||||
@@ -119,9 +119,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
|
||||
|
||||
if (n_ctx_req > n_kv_max) {
|
||||
continue;
|
||||
}
|
||||
//if (n_ctx_req > n_kv_max) {
|
||||
// continue;
|
||||
//}
|
||||
|
||||
common_batch_clear(batch);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user