Compare commits

...

95 Commits

Author SHA1 Message Date
Georgi Gerganov
624f7bd03b graph : add comments
ggml-ci
2025-02-28 21:13:08 +02:00
Georgi Gerganov
0f7daa9d1b graph : move non-context related logic to llm_build_context
ggml-ci
2025-02-28 20:36:25 +02:00
Georgi Gerganov
9cab53c7dd cont : migrate the rest of the inputs out of llama_context
ggml-ci
2025-02-28 18:01:25 +02:00
Georgi Gerganov
7f02ee562e context : decouple inputs, llama_graph_i become const (WIP)
ggml-ci
2025-02-28 16:30:41 +02:00
Georgi Gerganov
38db8a5861 llama : introduce concept of llama_memory
ggml-ci
2025-02-28 10:51:17 +02:00
Georgi Gerganov
828effd9d7 kv-cache : basic abstraction
ggml-ci
2025-02-27 16:00:29 +02:00
Georgi Gerganov
82675a0180 Merge branch 'master' into gg/llama-kv-cache
ggml-ci
2025-02-27 15:10:18 +02:00
Georgi Gerganov
952feedfca context : disable encoder embd tensor for now
ggml-ci
2025-02-27 15:07:10 +02:00
Georgi Gerganov
4efe989886 context : pass embeddings tensor from encoder to decoder
ggml-ci
2025-02-25 16:11:17 +02:00
Georgi Gerganov
e2b3294f2c context : fix enc-dec state save/load
ggml-ci
2025-02-25 12:14:34 +02:00
Georgi Gerganov
e5bc5f8e02 context : enc-dec is now working
ggml-ci
2025-02-25 12:10:34 +02:00
Georgi Gerganov
be58e30017 enc-dec : compose wip
ggml-ci
2025-02-24 18:12:24 +02:00
Georgi Gerganov
9cd78f11a1 context : explicit llama_context_i abstract interface
ggml-ci
2025-02-24 13:38:11 +02:00
Georgi Gerganov
4a1054b552 context : reuse built_attn_mha
ggml-ci
2025-02-24 11:29:52 +02:00
Georgi Gerganov
a5a85a3bc0 context : fix recurrent reserve
ggml-ci
2025-02-24 08:59:12 +02:00
Georgi Gerganov
0699a44c83 context : remove redundant virtual, protected -> private
ggml-ci
2025-02-23 20:02:11 +02:00
Georgi Gerganov
6378112cb5 graph : remove the build_kv_... API from llama_graph_i
ggml-ci
2025-02-23 19:39:22 +02:00
Georgi Gerganov
372fa3a894 cont : enc should work now, next is dec
ggml-ci
2025-02-23 12:20:23 +02:00
Georgi Gerganov
f5e80208c5 wip enc-dec 2025-02-21 19:17:47 +02:00
Georgi Gerganov
c4c0a4d13c Merge branch 'master' into gg/llama-kv-cache 2025-02-21 19:14:07 +02:00
Georgi Gerganov
3753b30d65 context : fix n_outputs init
ggml-ci
2025-02-21 15:53:26 +02:00
Georgi Gerganov
f588a70da3 context : wrap input tensors in struct
ggml-ci
2025-02-21 15:09:28 +02:00
Georgi Gerganov
ebf1bdf97b context : add logs
ggml-ci
2025-02-21 14:35:23 +02:00
Georgi Gerganov
548c230dff graph : remove worst_case from the API
ggml-ci
2025-02-21 13:29:25 +02:00
Georgi Gerganov
2645a7d9a9 context : add save/load for recurrent context
ggml-ci
2025-02-21 10:28:42 +02:00
Georgi Gerganov
08011c2ca1 context : add llama_kv_cache_recurrent prototype
ggml-ci
2025-02-20 20:55:13 +02:00
Georgi Gerganov
ad870c49f4 context : fix causal input for cache-less case
ggml-ci
2025-02-20 20:01:02 +02:00
Georgi Gerganov
b1554be1d7 context : add cache-less llama_context
ggml-ci
2025-02-20 18:30:04 +02:00
Georgi Gerganov
072280ea6b Merge branch 'master' into gg/llama-kv-cache
ggml-ci
2025-02-20 14:26:43 +02:00
Georgi Gerganov
f95b04a21c model : fix order kvq -> qkv
ggml-ci
2025-02-19 18:52:20 +02:00
Georgi Gerganov
2eacb4c1bf graph : simplify attention api
ggml-ci
2025-02-19 18:43:49 +02:00
Georgi Gerganov
e17e4b72d1 context : add llama_context_recurrent
ggml-ci
2025-02-19 16:07:27 +02:00
Georgi Gerganov
5f11a5502a kv-cache : remove llama_kv_cache_i 2025-02-19 14:36:27 +02:00
Georgi Gerganov
f5cedbcaaa kv-cache : prepare for abstraction
ggml-ci
2025-02-18 21:28:58 +02:00
Georgi Gerganov
2bffc2d514 model : pass llama_graph_i as ptr
ggml-ci
2025-02-18 14:57:26 +02:00
Georgi Gerganov
9e50456e19 context : minor simplify
ggml-ci
2025-02-18 14:53:02 +02:00
Georgi Gerganov
befe14f06f llama : reorder encode/decode in sources 2025-02-18 14:47:53 +02:00
Georgi Gerganov
bc6f187e9c cont : use returend tensors from the graph build
ggml-ci
2025-02-18 14:24:17 +02:00
Georgi Gerganov
172f61690c cont : return important tensors
ggml-ci
2025-02-18 13:48:43 +02:00
Georgi Gerganov
c23590319a graph : add llama_graph_result
ggml-ci
2025-02-18 13:48:21 +02:00
Georgi Gerganov
f0d3ff2388 Merge branch 'master' into gg/llama-kv-cache
ggml-ci
2025-02-18 10:14:37 +02:00
Georgi Gerganov
1d801d27b9 graph : update attn/kv_self names 2025-02-14 17:22:55 +02:00
Georgi Gerganov
828064564c context : move common inputs to base class
ggml-ci
2025-02-14 16:48:21 +02:00
Georgi Gerganov
d5e8e1a2ba context : remove batch_manager
ggml-ci
2025-02-14 16:10:55 +02:00
Georgi Gerganov
131743ff4f context : abstract constructor and init
ggml-ci
2025-02-13 17:17:51 +02:00
Georgi Gerganov
ed3cb55abe context : abstract input
ggml-ci
2025-02-13 15:53:15 +02:00
Georgi Gerganov
107d1e2c32 context : move output functionality to base class
ggml-ci
2025-02-13 15:42:14 +02:00
Georgi Gerganov
e08f38df69 context : minor cleanup
ggml-ci
2025-02-13 12:50:53 +02:00
Georgi Gerganov
f7c7757bab context : abstract state read/write
ggml-ci
2025-02-13 12:37:28 +02:00
Georgi Gerganov
3a504d9a0b llama : introduce llama_io interfaces
ggml-ci
2025-02-13 12:25:54 +02:00
Georgi Gerganov
fbe6a07256 context : rename to llama_context_kv_self 2025-02-12 17:16:44 +02:00
Georgi Gerganov
6ee86e5e0f graph : restore ubatch in build_cb
ggml-ci
2025-02-12 16:29:15 +02:00
Georgi Gerganov
f63aeecce6 llama : models now build their graphs using llama_graph_i
ggml-ci
2025-02-12 15:08:40 +02:00
Georgi Gerganov
0ab50f1bbb context : prepare llama_model graph build
ggml-ci
2025-02-12 14:09:55 +02:00
Georgi Gerganov
e633dc171a context : introduce llama_graph_i
ggml-ci
2025-02-12 13:49:44 +02:00
Georgi Gerganov
5eae8e5183 context : move build_rope_factors to base class
ggml-ci
2025-02-12 13:32:02 +02:00
Georgi Gerganov
d146a14f77 context : minor naming fix 2025-02-12 12:41:36 +02:00
Georgi Gerganov
8da7f612b7 context : improve llama_context encapsulation
ggml-ci
2025-02-12 12:15:04 +02:00
Georgi Gerganov
b52b79b048 context : move encode/decode to llama-context.cpp 2025-02-12 11:23:38 +02:00
Georgi Gerganov
02ef4be975 context : initial abstraction
ggml-ci
2025-02-11 22:27:21 +02:00
Georgi Gerganov
2cd8a903c8 context : make output functions members
ggml-ci
2025-02-10 17:01:27 +02:00
Georgi Gerganov
d1d8d53008 bman : remove ubatch member
ggml-ci
2025-02-10 16:50:14 +02:00
Georgi Gerganov
ef358ee78f context : add decode/encode
ggml-ci
2025-02-10 16:14:13 +02:00
Georgi Gerganov
879ba82777 server : increase context size for the tests
ggml-ci
2025-02-10 15:00:02 +02:00
Georgi Gerganov
f9971ef2e1 llama : dedup reserve code 2025-02-10 14:59:51 +02:00
Georgi Gerganov
972f91c7d7 Merge branch 'master' into gg/llama-kv-cache
ggml-ci
2025-02-10 14:45:54 +02:00
Georgi Gerganov
b15fede7a9 kv-cache : fix defrag condition
ggml-ci
2025-02-06 14:35:19 +02:00
Georgi Gerganov
0f1c1cab2c Merge branch 'master' into gg/llama-kv-cache
ggml-ci
2025-02-06 10:04:33 +02:00
Georgi Gerganov
e0d913fccb llama : clear whitespaces 2025-02-06 10:02:50 +02:00
Molly Sophia
1eca8916b5 llama : fix rwkv inference (#11618)
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
2025-02-03 14:17:50 +02:00
Georgi Gerganov
74b0807245 Merge branch 'master' into gg/llama-kv-cache
ggml-ci
2025-02-02 11:07:05 +02:00
Georgi Gerganov
3e23be7911 context : store graph build function callback
ggml-ci
2025-02-02 10:49:32 +02:00
Georgi Gerganov
5d3491e789 Merge branch 'master' into gg/llama-kv-cache
ggml-ci
2025-01-31 15:11:11 +02:00
Georgi Gerganov
a40ba49fa6 Merge branch 'master' into gg/llama-kv-cache 2025-01-30 16:39:58 +02:00
Georgi Gerganov
c30e34cdba Merge branch 'master' into gg/llama-kv-cache
ggml-ci
2025-01-29 15:01:26 +02:00
Georgi Gerganov
918885697e llama : resolve rwkv conflict
ggml-ci
2025-01-29 14:45:04 +02:00
Georgi Gerganov
e665b57fa2 Merge branch 'master' into gg/llama-kv-cache
ggml-ci
2025-01-27 14:09:22 +02:00
Georgi Gerganov
a0c500b4dc context : prepare for abstraction
ggml-ci
2025-01-26 20:16:22 +02:00
Georgi Gerganov
99422dfa3f context : introduce llama_batch_manager
ggml-ci
2025-01-26 20:16:22 +02:00
Georgi Gerganov
cb8f2095c6 wip 2025-01-26 20:16:22 +02:00
Georgi Gerganov
133ad6a723 context : initial need_reserve logic
ggml-ci
2025-01-26 20:16:22 +02:00
Georgi Gerganov
c75ba6851e context : move adapter code in the implementation [no ci] 2025-01-26 20:16:22 +02:00
Georgi Gerganov
f0713498fd context : add get_ctx_padding()
ggml-ci
2025-01-26 20:16:22 +02:00
Georgi Gerganov
b4ec1d4429 cont : move kv_self update to llama_context
ggml-ci
2025-01-26 20:16:21 +02:00
Georgi Gerganov
f2524c0e41 llama : remove references to llama_kv_cache (wip)
Intermediate step necessary to abstract the `llama_context` and
`llama_kv_cache`.

ggml-ci
2025-01-26 20:16:21 +02:00
Georgi Gerganov
ae274f9747 llama : fix names [no ci] 2025-01-26 20:16:21 +02:00
Georgi Gerganov
a19f671fe0 context : minor
ggml-ci
2025-01-26 20:16:21 +02:00
Georgi Gerganov
17b363afd3 llama : update llama_kv_self API
ggml-ci
2025-01-26 20:16:20 +02:00
Georgi Gerganov
fd05ab87aa kv_cache : move state read/write to llama_kv_cache
ggml-ci
2025-01-26 20:14:36 +02:00
Georgi Gerganov
4cd1b6fa4c context : prepare kv_cache_read/write to be moved to kv_cache
ggml-ci
2025-01-26 20:14:36 +02:00
Georgi Gerganov
73a14eccc9 kv_cache : minor 2025-01-26 20:14:36 +02:00
Georgi Gerganov
fef90cb3d7 kv_cache : fix
ggml-ci
2025-01-26 20:14:36 +02:00
Georgi Gerganov
4d7bd03e65 kv_cache : functions -> members
ggml-ci
2025-01-26 20:14:36 +02:00
Georgi Gerganov
e4550fbafc llama : cont
ggml-ci
2025-01-26 20:14:35 +02:00
Georgi Gerganov
f78b396ee7 llama : add struct llama_kv_cache (wip) [no ci] 2025-01-26 20:12:06 +02:00
44 changed files with 16710 additions and 11748 deletions

View File

@@ -951,8 +951,8 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
params.ctx_shift = false;
}
@@ -1056,7 +1056,7 @@ struct common_init_result common_init_from_params(common_params & params) {
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
}
llama_kv_cache_clear(lctx);
llama_kv_self_clear(lctx);
llama_synchronize(lctx);
llama_perf_context_reset(lctx);
}
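The hunks in this file swap the per-context KV-cache helpers from their llama_kv_cache_* names to llama_kv_self_*, and the warning now refers to the context rather than the model. A minimal caller-side sketch of the renamed calls, assuming the declarations added to llama.h later in this diff (ctx and params are hypothetical stand-ins for the example's locals):

    // sketch only: reset a context's own KV cache using the new entry points
    static void reset_kv(llama_context * ctx, common_params & params) {
        if (params.ctx_shift && !llama_kv_self_can_shift(ctx)) {
            // this context cannot shift its KV cache, so disable context shifting
            params.ctx_shift = false;
        }
        llama_kv_self_clear(ctx);   // erase cell info and zero the KV data
        llama_synchronize(ctx);     // wait for any in-flight work to finish
    }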

View File

@@ -172,7 +172,7 @@ llama_tokens common_speculative_gen_draft(
result.reserve(params.n_draft);
if (reuse_n == 0) {
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
prompt.clear();
} else {
@@ -191,14 +191,14 @@ llama_tokens common_speculative_gen_draft(
}
if (reuse_i > 0) {
llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
llama_kv_self_seq_rm (ctx, 0, 0, reuse_i);
llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
}
if (reuse_n < (int) prompt.size()) {
llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
llama_kv_self_seq_rm (ctx, 0, reuse_n, -1);
prompt.erase(prompt.begin() + reuse_n, prompt.end());
}

View File

@@ -132,7 +132,7 @@ int main(int argc, char ** argv) {
const auto t_pp_start = ggml_time_us();
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -141,7 +141,7 @@ int main(int argc, char ** argv) {
if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
}
}

View File

@@ -116,7 +116,7 @@ if llama_decode(context, batch) != 0 {
}
for i in 1 ..< n_parallel {
llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
llama_kv_self_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
}
if n_parallel > 1 {

View File

@@ -342,7 +342,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
}
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;

View File

@@ -37,7 +37,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
const struct llama_model * model = llama_get_model(ctx);
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

View File

@@ -45,7 +45,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
}
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
llama_set_embeddings(ctx, true);
llama_set_causal_attn(ctx, false);
@@ -102,7 +102,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
llama_token eos_token = llama_vocab_eos(vocab);
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);

View File

@@ -498,7 +498,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1);

View File

@@ -332,8 +332,8 @@ int main(int argc, char ** argv) {
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
llama_kv_self_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_self_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
n_past -= n_discard;

View File

@@ -1578,7 +1578,7 @@ int main(int argc, char ** argv) {
test t(inst, lmodel, ctx);
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
// cool off before the test
if (params.delay) {
@@ -1618,7 +1618,7 @@ int main(int argc, char ** argv) {
}
for (int i = 0; i < params.reps; i++) {
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
uint64_t t_start = get_time_ns();

View File

@@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
}
batch->logits[batch->n_tokens - 1] = true;
llama_kv_cache_clear(context);
llama_kv_self_clear(context);
const auto t_pp_start = ggml_time_us();
if (llama_decode(context, *batch) != 0) {
@@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
LOGi("Benchmark text generation (tg)");
llama_kv_cache_clear(context);
llama_kv_self_clear(context);
const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) {
@@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
const auto t_tg_end = ggml_time_us();
llama_kv_cache_clear(context);
llama_kv_self_clear(context);
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@@ -448,5 +448,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
llama_kv_self_clear(reinterpret_cast<llama_context *>(context));
}

View File

@@ -210,7 +210,7 @@ actor LlamaContext {
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
llama_kv_cache_clear(context)
llama_kv_self_clear(context)
let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
@@ -223,7 +223,7 @@ actor LlamaContext {
// bench text generation
llama_kv_cache_clear(context)
llama_kv_self_clear(context)
let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
@@ -242,7 +242,7 @@ actor LlamaContext {
let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000;
llama_kv_cache_clear(context)
llama_kv_self_clear(context)
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@@ -292,7 +292,7 @@ actor LlamaContext {
func clear() {
tokens_list.removeAll()
temporary_invalid_cchars.removeAll()
llama_kv_cache_clear(context)
llama_kv_self_clear(context)
}
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {

View File

@@ -95,7 +95,7 @@ int main(int argc, char ** argv) {
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
}
const auto t_enc_end = ggml_time_us();
@@ -437,17 +437,17 @@ int main(int argc, char ** argv) {
// KV cache management
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
llama_kv_self_seq_rm(ctx, -1, n_past, -1);
if (seq_id_best != 0) {
// if a verification token matched, we keep the best sequence and remove the rest
// this leads to some KV cache fragmentation
llama_kv_cache_seq_keep(ctx, seq_id_best);
llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
llama_kv_self_seq_keep(ctx, seq_id_best);
llama_kv_self_seq_cp (ctx, seq_id_best, 0, -1, -1);
llama_kv_self_seq_rm (ctx, seq_id_best, -1, -1);
for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
}
}
}
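The lookahead hunk above keeps only the branch whose verification token matched: the winning sequence is preserved, folded back into sequence 0, and the lookahead sequences are re-seeded from it. A condensed sketch of that cycle with the renamed calls (ctx, seq_id_best, W and G are the example's own locals):

    // keep the winning branch, fold it into sequence 0, then fan back out
    llama_kv_self_seq_keep(ctx, seq_id_best);             // drop all other sequences
    llama_kv_self_seq_cp  (ctx, seq_id_best, 0, -1, -1);  // copy the winner into seq 0
    llama_kv_self_seq_rm  (ctx, seq_id_best, -1, -1);     // remove the now-redundant copy
    for (int s = 1; s < W + G + 1; ++s) {
        llama_kv_self_seq_cp(ctx, 0, s, -1, -1);          // re-seed the lookahead sequences
    }

As the diff's comments note, keeping a matched branch causes some KV-cache fragmentation, while the no-match path simply removes all cells from the batch.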

View File

@@ -192,7 +192,7 @@ int main(int argc, char ** argv){
// KV cache management
// clean the cache of draft tokens that weren't accepted
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
llama_kv_self_seq_rm(ctx, 0, n_past, -1);
common_batch_clear(batch_tgt);
common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

View File

@@ -330,7 +330,7 @@ int main(int argc, char ** argv) {
}
// remove any "future" tokens that we might have inherited from the previous session
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
llama_kv_self_seq_rm(ctx, -1, n_matching_session_tokens, -1);
}
LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
@@ -573,8 +573,8 @@ int main(int argc, char ** argv) {
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
llama_kv_self_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_self_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
n_past -= n_discard;
@@ -597,9 +597,9 @@ int main(int argc, char ** argv) {
LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
llama_kv_self_seq_add(ctx, 0, ga_i, n_past, ib*bd);
llama_kv_self_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
llama_kv_self_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
n_past -= bd;

View File

@@ -201,7 +201,7 @@ int main(int argc, char ** argv) {
// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
}
LOG_INF("\n");
@@ -233,9 +233,9 @@ int main(int argc, char ** argv) {
if (batch.n_tokens == 0) {
// all sequences have ended - clear the entire KV cache
for (int i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, -1, -1);
llama_kv_self_seq_rm(ctx, i, -1, -1);
// but keep the system prompt
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
}
LOG_INF("%s: clearing the KV cache\n", __func__);
@@ -371,8 +371,8 @@ int main(int argc, char ** argv) {
}
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
llama_kv_self_seq_rm(ctx, client.id + 1, -1, -1);
llama_kv_self_seq_cp(ctx, 0, client.id + 1, -1, -1);
const auto t_main_end = ggml_time_us();

View File

@@ -132,11 +132,11 @@ int main(int argc, char ** argv) {
const int ib = i/n_batch - 1;
const int bd = n_batch_grp*(n_grp - 1);
llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
llama_kv_cache_update (ctx);
llama_kv_self_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
llama_kv_self_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
llama_kv_self_update (ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
}
common_batch_clear(batch);
@@ -166,12 +166,12 @@ int main(int argc, char ** argv) {
LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard);
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_cache_defrag (ctx);
llama_kv_cache_update (ctx);
llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_self_defrag (ctx);
llama_kv_self_update (ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
common_batch_clear(batch);
@@ -197,12 +197,12 @@ int main(int argc, char ** argv) {
if (n_discard > 0) {
LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_cache_defrag (ctx);
llama_kv_cache_update (ctx);
llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_self_defrag (ctx);
llama_kv_self_update (ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
}
}
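The passkey hunks above run the full shift/divide/update cycle on a single sequence. A condensed sketch of one such step with the renamed API (ctx, n_past, n_keep, n_discard and n_ctx stand in for the example's locals):

    // discard a window of cells, shift the remainder back, then apply the update
    llama_kv_self_seq_rm (ctx, 0, n_keep,             n_keep + n_discard);
    llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
    llama_kv_self_update (ctx);                         // applies the pending K-shift / defrag
    n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;     // next position after the shift

llama_kv_self_update replaces llama_kv_cache_update; as the header comments below explain, the same update also happens lazily on the next llama_decode().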

View File

@@ -361,7 +361,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1);
@@ -547,7 +547,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@@ -924,7 +924,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
return;
}
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1203,7 +1203,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
return;
}
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1575,7 +1575,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
return;
}
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1765,7 +1765,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
}
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1);

View File

@@ -83,7 +83,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

View File

@@ -891,7 +891,7 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
// Function to tokenize the prompt
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;
const bool is_first = llama_kv_self_used_cells(llama_data.context.get()) == 0;
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
prompt_tokens.resize(n_prompt_tokens);
@@ -907,7 +907,7 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
// Check if we have enough space in the context to evaluate this batch
static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
const int n_ctx = llama_n_ctx(ctx.get());
const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get());
const int n_ctx_used = llama_kv_self_used_cells(ctx.get());
if (n_ctx_used + batch.n_tokens > n_ctx) {
printf(LOG_COL_DEFAULT "\n");
printe("context size exceeded\n");

View File

@@ -15,7 +15,7 @@ int main(int argc, char ** argv) {
return 1;
}
print_build_info();
common_init();
if (params.n_predict < 0) {
params.n_predict = 16;
@@ -196,7 +196,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
// erase whole kv
llama_kv_cache_clear(ctx3);
llama_kv_self_clear(ctx3);
fprintf(stderr, "%s : kv cache cleared\n", __func__);
// restore kv into seq 1

View File

@@ -2083,7 +2083,7 @@ struct server_context {
SRV_DBG("%s", "clearing KV cache\n");
// clear the entire KV cache
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
clean_kv_cache = false;
}
@@ -2625,8 +2625,8 @@ struct server_context {
res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
res->t_start = metrics.t_start;
res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx);
res->kv_cache_tokens_count = llama_kv_self_n_tokens(ctx);
res->kv_cache_used_cells = llama_kv_self_used_cells(ctx);
res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
res->t_prompt_processing_total = metrics.t_prompt_processing_total;
@@ -2742,7 +2742,7 @@ struct server_context {
// Erase token cache
const size_t n_erased = slot->cache_tokens.size();
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
llama_kv_self_seq_rm(ctx, slot->id, -1, -1);
slot->cache_tokens.clear();
auto res = std::make_unique<server_task_result_slot_erase>();
@@ -2810,8 +2810,8 @@ struct server_context {
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@@ -3002,8 +3002,8 @@ struct server_context {
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift);
llama_kv_self_seq_rm (ctx, slot.id, head_p, head_c);
llama_kv_self_seq_add(ctx, slot.id, head_c, -1, kv_shift);
for (size_t i = 0; i < n_match; i++) {
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
@@ -3041,9 +3041,9 @@ struct server_context {
}
// keep only the common part
if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
llama_kv_self_seq_rm(ctx, slot.id, -1, -1);
// there is no common part left
slot.n_past = 0;
@@ -3283,7 +3283,7 @@ struct server_context {
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;

View File

@@ -283,7 +283,7 @@ class ServerPreset:
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "tinyllamas/stories260K.gguf"
server.model_alias = "tinyllama-2"
server.n_ctx = 256
server.n_ctx = 512
server.n_batch = 32
server.n_slots = 2
server.n_predict = 64

View File

@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
auto generate = [&](const std::string & prompt) {
std::string response;
const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0;
const bool is_first = llama_kv_self_used_cells(ctx) == 0;
// tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
while (true) {
// check if we have enough space in the context to evaluate this batch
int n_ctx = llama_n_ctx(ctx);
int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
int n_ctx_used = llama_kv_self_used_cells(ctx);
if (n_ctx_used + batch.n_tokens > n_ctx) {
printf("\033[0m\n");
fprintf(stderr, "context size exceeded\n");

View File

@@ -217,7 +217,7 @@ int main(int argc, char ** argv) {
{
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
llama_kv_self_seq_rm(ctx_tgt, 0, n_past, -1);
}
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {

View File

@@ -420,14 +420,14 @@ int main(int argc, char ** argv) {
{
LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
llama_kv_cache_seq_keep(ctx_dft, s_keep);
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_dft, 0);
llama_kv_self_seq_keep(ctx_dft, s_keep);
llama_kv_self_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_self_seq_keep(ctx_dft, 0);
llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_tgt, 0);
llama_kv_self_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_kv_self_seq_keep(ctx_tgt, s_keep);
llama_kv_self_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_self_seq_keep(ctx_tgt, 0);
}
for (int s = 0; s < n_seq_dft; ++s) {
@@ -444,7 +444,7 @@ int main(int argc, char ** argv) {
common_batch_clear(batch_dft);
common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode(ctx_dft, batch_dft);
@@ -503,8 +503,8 @@ int main(int argc, char ** argv) {
if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) {
LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
llama_kv_self_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
// all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
@@ -585,9 +585,9 @@ int main(int argc, char ** argv) {
// evaluate the target model on the drafted tokens
{
llama_kv_cache_seq_keep(ctx_tgt, 0);
llama_kv_self_seq_keep(ctx_tgt, 0);
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1);
}
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());

View File

@@ -60,6 +60,7 @@ extern "C" {
struct llama_model;
struct llama_context;
struct llama_sampler;
struct llama_kv_cache;
typedef int32_t llama_pos;
typedef int32_t llama_token;
@@ -467,8 +468,9 @@ extern "C" {
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); // TODO: remove const?
LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
@@ -585,7 +587,7 @@ extern "C" {
// KV cache
//
// TODO: remove llama_kv_cache_view_* API
// TODO: start using struct llama_kv_cache
// Information associated with an individual cell in the KV cache view.
struct llama_kv_cache_view_cell {
@@ -640,13 +642,19 @@ extern "C" {
// Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
"use llama_kv_self_n_tokens instead");
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
"use llama_kv_self_used_cells instead");
// Clear the KV cache - both cell info is erased and KV data is zeroed
LLAMA_API void llama_kv_cache_clear(
LLAMA_API void llama_kv_self_clear(
struct llama_context * ctx);
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
@@ -654,73 +662,125 @@ extern "C" {
// seq_id < 0 : match any sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API bool llama_kv_cache_seq_rm(
LLAMA_API bool llama_kv_self_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
// Copy all tokens that belong to the specified sequence to another sequence
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_cp(
LLAMA_API void llama_kv_self_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
// Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_kv_cache_seq_keep(
LLAMA_API void llama_kv_self_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id);
llama_seq_id seq_id);
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode()
// - explicitly with llama_kv_cache_update()
// - explicitly with llama_kv_self_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_add(
LLAMA_API void llama_kv_self_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
// Integer division of the positions by factor of `d > 1`
// If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode()
// - explicitly with llama_kv_cache_update()
// - explicitly with llama_kv_self_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_div(
LLAMA_API void llama_kv_self_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
// Returns the largest position present in the KV cache for the specified sequence
LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id);
// TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
// how to avoid this?
llama_seq_id seq_id);
// Defragment the KV cache
// This will be applied:
// - lazily on next llama_decode()
// - explicitly with llama_kv_cache_update()
LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
// - explicitly with llama_kv_self_update()
LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
// Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
DEPRECATED(LLAMA_API void llama_kv_cache_clear(
struct llama_context * ctx),
"use llama_kv_self_clear instead");
DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1),
"use llama_kv_self_seq_rm instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1),
"use llama_kv_self_seq_cp instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_kv_self_seq_keep instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta),
"use llama_kv_self_seq_add instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d),
"use llama_kv_self_seq_div instead");
DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_kv_self_seq_pos_max instead");
DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
"use llama_kv_self_defrag instead");
DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
"use llama_kv_self_can_shift instead");
DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
"use llama_kv_self_update instead");
//
// State / sessions
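// [sketch, not part of llama.h] The deprecated llama_kv_cache_* wrappers above
// are assumed to simply forward to their llama_kv_self_* replacements in the
// suppressed llama-context.cpp diff, e.g.:
//
//   void llama_kv_cache_clear(struct llama_context * ctx) {
//       llama_kv_self_clear(ctx);
//   }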

View File

@@ -15,8 +15,10 @@ add_library(llama
llama-chat.cpp
llama-context.cpp
llama-grammar.cpp
llama-graph.cpp
llama-hparams.cpp
llama-impl.cpp
llama-io.cpp
llama-kv-cache.cpp
llama-mmap.cpp
llama-model-loader.cpp

View File

@@ -91,7 +91,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
return true;
}
int32_t llama_adapter_cvec::apply(
bool llama_adapter_cvec::apply(
const llama_model & model,
const float * data,
size_t len,
@@ -104,17 +104,17 @@ int32_t llama_adapter_cvec::apply(
// disable the current control vector (but leave allocated for later)
layer_start = -1;
layer_end = -1;
return 0;
return true;
}
if (n_embd != (int) hparams.n_embd) {
LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
return 1;
return false;
}
if (tensors.empty()) {
if (!init(model)) {
return 1;
return false;
}
}
@@ -130,7 +130,7 @@ int32_t llama_adapter_cvec::apply(
}
}
return 0;
return true;
}
// lora
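The apply() change above replaces the int32_t status code (0 on success, 1 on failure) with a bool, matching the apply_adapter_cvec() signature in the new llama_context_i interface further down. A call-site sketch under that assumption (cvec and the argument names are stand-ins):

    // callers can now branch directly on the boolean result
    if (!cvec.apply(model, data, len, n_embd, il_start, il_end)) {
        LLAMA_LOG_ERROR("%s: failed to apply the control vector\n", __func__);
        return false;
    }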

View File

@@ -19,7 +19,7 @@ struct llama_adapter_cvec {
struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;
int32_t apply(
bool apply(
const llama_model & model,
const float * data,
size_t len,

View File

@@ -42,9 +42,9 @@ struct llama_sbatch {
bool logits_all; // TODO: remove once lctx.logits_all is removed too
// sorted indices into the batch
std::vector<size_t> ids;
std::vector<int64_t> ids;
// batch indices of the output
std::vector<size_t> out_ids;
std::vector<int64_t> out_ids;
std::vector<llama_sbatch_seq> seq;
const llama_batch * batch = nullptr;

File diff suppressed because it is too large

View File

@@ -3,6 +3,7 @@
#include "llama.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-graph.h"
#include "llama-model.h"
#include "llama-kv-cache.h"
#include "llama-adapter.h"
@@ -14,55 +15,435 @@
#include <vector>
#include <set>
struct llama_context {
llama_context(const llama_model & model)
: model(model)
, t_start_us(model.t_start_us)
, t_load_us(model.t_load_us) {}
class llama_io_read_i;
class llama_io_write_i;
const struct llama_model & model;
using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
struct llama_cparams cparams;
struct llama_sbatch sbatch; // TODO: revisit if needed
struct llama_kv_cache kv_self;
struct llama_adapter_cvec cvec;
// abstract interface corresponding to the public C API
class llama_context_i {
public:
llama_context_i() = default;
virtual ~llama_context_i() = default;
std::unordered_map<struct llama_adapter_lora *, float> lora;
virtual void init() = 0;
std::vector<ggml_backend_ptr> backends;
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
virtual void synchronize() = 0;
ggml_backend_t backend_cpu = nullptr;
virtual const llama_model & get_model() const = 0;
ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr;
virtual uint32_t n_ctx() const = 0;
virtual uint32_t n_ctx_per_seq() const = 0;
virtual uint32_t n_batch() const = 0;
virtual uint32_t n_ubatch() const = 0;
virtual uint32_t n_seq_max() const = 0;
bool has_evaluated_once = false;
virtual uint32_t n_threads() const = 0;
virtual uint32_t n_threads_batch() const = 0;
mutable int64_t t_start_us;
mutable int64_t t_load_us;
// self-attention:
// if the context does not have a KV cache, return nullptr
virtual llama_kv_cache * get_kv_self() = 0;
virtual const llama_kv_cache * get_kv_self() const = 0;
// if the context does not have a KV cache, noop
virtual void kv_self_update() = 0;
virtual enum llama_pooling_type pooling_type() const = 0;
virtual float * get_logits() = 0;
virtual float * get_logits_ith(int32_t i) = 0;
virtual float * get_embeddings() = 0;
virtual float * get_embeddings_ith(int32_t i) = 0;
virtual float * get_embeddings_seq(llama_seq_id seq_id) = 0;
virtual void attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch) = 0;
virtual void detach_threadpool() = 0;
virtual void set_n_threads(int32_t n_threads, int32_t n_threads_batch) = 0;
virtual void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) = 0;
virtual void set_embeddings (bool value) = 0;
virtual void set_causal_attn(bool value) = 0;
virtual void set_adapter_lora(
llama_adapter_lora * adapter,
float scale) = 0;
virtual bool rm_adapter_lora(
llama_adapter_lora * adapter) = 0;
virtual void clear_adapter_lora() = 0;
virtual bool apply_adapter_cvec(
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end) = 0;
// encode a batch of tokens by evaluating the encoder part of the transformer
//
// - lctx: llama context
// - batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int encode(llama_batch & inp_batch) = 0;
// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models)
//
// - lctx: llama context
// - inp_batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int decode(llama_batch & inp_batch) = 0;
//
// perf
//
virtual llama_perf_context_data perf_get_data() const = 0;
virtual void perf_reset() = 0;
//
// state save/load
//
virtual size_t state_get_size() = 0;
virtual size_t state_get_data( uint8_t * dst, size_t size) = 0;
virtual size_t state_set_data(const uint8_t * src, size_t size) = 0;
virtual size_t state_seq_get_size(llama_seq_id seq_id) = 0;
virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) = 0;
virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) = 0;
virtual bool state_load_file(
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) = 0;
virtual bool state_save_file(
const char * filepath,
const llama_token * tokens,
size_t n_token_count) = 0;
virtual size_t state_seq_load_file(
llama_seq_id seq_id,
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) = 0;
virtual size_t state_seq_save_file(
llama_seq_id seq_id,
const char * filepath,
const llama_token * tokens,
size_t n_token_count) = 0;
};
// C alias
struct llama_context : public llama_context_i {
using llama_context_i::llama_context_i;
};
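// [sketch, not part of this header] With llama_context reduced to a thin alias
// over the abstract llama_context_i, the public C entry points can dispatch
// straight through the virtuals, presumably in the suppressed
// llama-context.cpp diff, e.g.:
//
//   struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx) {
//       return ctx->get_kv_self();
//   }
//
//   void llama_kv_self_update(struct llama_context * ctx) {
//       ctx->kv_self_update();
//   }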
// basic transformer without KV cache
class llama_context_base : public llama_context, public llama_graph_i {
public:
llama_context_base(
const llama_model & model,
llama_context_params params,
llama_graph_type gtype);
virtual ~llama_context_base();
// init scheduler and compute buffers, reserve worst-case graphs
// call once after the context is constructed
void init() override;
void synchronize() override;
protected:
// called by init() to reserve the worst-case graphs
// override in child classes
virtual void reserve();
public:
const llama_model & get_model() const override;
uint32_t n_ctx() const override;
uint32_t n_ctx_per_seq() const override;
uint32_t n_batch() const override;
uint32_t n_ubatch() const override;
uint32_t n_seq_max() const override;
uint32_t n_threads() const override;
uint32_t n_threads_batch() const override;
llama_kv_cache * get_kv_self() override;
const llama_kv_cache * get_kv_self() const override;
void kv_self_update() override;
enum llama_pooling_type pooling_type() const override;
float * get_logits() override;
float * get_logits_ith(int32_t i) override;
float * get_embeddings() override;
float * get_embeddings_ith(int32_t i) override;
float * get_embeddings_seq(llama_seq_id seq_id) override;
void attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch) override;
void detach_threadpool() override;
void set_n_threads(int32_t n_threads, int32_t n_threads_batch) override;
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) override;
void set_embeddings (bool value) override;
void set_causal_attn(bool value) override;
void set_adapter_lora(
llama_adapter_lora * adapter,
float scale) override;
bool rm_adapter_lora(
llama_adapter_lora * adapter) override;
void clear_adapter_lora() override;
bool apply_adapter_cvec(
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end) override;
int encode(llama_batch & inp_batch) override;
int decode(llama_batch & inp_batch) override;
protected:
//
// output
//
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
virtual int32_t output_reserve(int32_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
// TODO: maybe remove this
virtual void output_reorder();
//
// graph
//
virtual int32_t graph_max_nodes() const;
// zero-out inputs and create the ctx_compute for the compute graph
virtual ggml_cgraph * graph_init();
// TODO: add encode/decode graphs
virtual llama_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch);
// returns the result of ggml_backend_sched_graph_compute_async execution
virtual enum ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);
ggml_context_ptr ctx_compute;
public:
//
// graph build
//
int32_t get_n_outputs() const override;
void build_cb(
ggml_tensor * cur,
const char * name,
const llama_ubatch & ubatch,
int il) const override;
// apply control vector for layer il
ggml_tensor * build_cvec(
ggml_context * ctx0,
ggml_tensor * cur,
int il) const override;
// do mat_mul, while optionally apply lora
ggml_tensor * build_lora_mm(
ggml_context * ctx0,
ggml_tensor * w,
ggml_tensor * cur) const override;
// do mat_mul_id, while optionally apply lora
ggml_tensor * build_lora_mm_id(
ggml_context * ctx0,
ggml_tensor * w, // struct ggml_tensor * as
ggml_tensor * cur, // struct ggml_tensor * b
ggml_tensor * ids) const override;
ggml_tensor * build_rope_factors(int il) const override;
llama_graph_input_ptr build_inp_embd(
ggml_context * ctx0,
ggml_tensor * tok_embd,
const llama_ubatch & ubatch) const override;
llama_graph_input_ptr build_inp_pos_bucket(
ggml_context * ctx0,
int32_t n_tokens) const override;
llama_graph_input_attn_ptr build_attn_inp(
ggml_context * ctx0,
int32_t n_tokens,
bool causal,
bool swa) const override;
ggml_tensor * build_attn(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) const override;
protected:
// note: optionally set the backend to be the same as the bbuf's backend
ggml_tensor * build_rope_shift(
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
ggml_backend_buffer * bbuf) const;
ggml_tensor * build_attn_mha(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
bool v_trans,
float kq_scale) const;
public:
//
// perf
//
llama_perf_context_data perf_get_data() const override;
void perf_reset() override;
protected:
// TODO: become private
mutable int64_t t_start_us = 0;
mutable int64_t t_load_us = 0;
mutable int64_t t_p_eval_us = 0;
mutable int64_t t_eval_us = 0;
mutable int64_t t_compute_start_us = 0;
mutable int64_t n_queued_tokens = 0;
mutable int64_t n_queued_tokens = 0;
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;
public:
//
// state save/load
//
size_t state_get_size() override;
size_t state_get_data( uint8_t * dst, size_t size) override;
size_t state_set_data(const uint8_t * src, size_t size) override;
size_t state_seq_get_size(llama_seq_id seq_id) override;
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) override;
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) override;
bool state_load_file(
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) override;
bool state_save_file(
const char * filepath,
const llama_token * tokens,
size_t n_token_count) override;
size_t state_seq_load_file(
llama_seq_id seq_id,
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) override;
size_t state_seq_save_file(
llama_seq_id seq_id,
const char * filepath,
const llama_token * tokens,
size_t n_token_count) override;
protected:
virtual size_t state_write_data(llama_io_write_i & io);
virtual size_t state_read_data (llama_io_read_i & io);
virtual size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
virtual size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
//
// members
//
// TODO: become private / move to llama_graph_i
const llama_model & model;
llama_cparams cparams;
llama_adapter_cvec cvec;
llama_loras loras;
llama_sbatch sbatch;
ggml_backend_sched_ptr sched;
// TODO: these below likely need some rework in the future, together with the batch-refactoring
// TODO: remove
bool logits_all = false;
// decode output (2-dimensional array: [n_outputs][n_vocab])
size_t logits_size = 0; // capacity (of floats) for logits
float * logits = nullptr;
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
bool logits_all = false;
// embeddings output (2-dimensional array: [n_outputs][n_embd])
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
size_t embd_size = 0; // capacity (of floats) for embeddings
@@ -72,56 +453,421 @@ struct llama_context {
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq;
// whether we are computing encoder output or decoder output
bool is_encoding = false;
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers
// TODO: find a better way to accommodate multi-dimensional position encoding methods
// number of position ids each token gets, 1 per token in most cases.
// when using m-rope, there are 3 position ids per token, representing a 3-dimensional coordinate.
int n_pos_per_token = 1;
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
// output of the encoder part of the encoder-decoder models
std::vector<float> embd_enc;
std::vector<std::set<llama_seq_id>> seq_ids_enc;
private:
// base functionality - should not leak into derived classes
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
ggml_backend_sched_ptr sched;
ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr;
ggml_abort_callback abort_callback = nullptr;
void * abort_callback_data = nullptr;
// input tensors
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
ggml_backend_t backend_cpu = nullptr;
std::vector<ggml_backend_ptr> backends;
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
// buffer types used for the compute buffer of each backend
std::vector<ggml_backend_t> backend_ptrs;
std::vector<ggml_backend_buffer_type_t> backend_buft;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;
bool has_evaluated_once = false;
};
// TODO: make these methods of llama_context
void llama_set_k_shift(struct llama_context & lctx);
void llama_set_s_copy(struct llama_context & lctx);
void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
void llama_output_reorder(struct llama_context & ctx);
// transformer with a self-attention KV cache
class llama_context_kv_self : public llama_context_base {
public:
llama_context_kv_self(
const llama_model & model,
llama_context_params params,
llama_graph_type gtype);
virtual ~llama_context_kv_self();
protected:
void reserve() override;
public:
llama_kv_cache * get_kv_self() override;
const llama_kv_cache * get_kv_self() const override;
void kv_self_update() override;
int encode(llama_batch & inp_batch) override;
int decode(llama_batch & inp_batch) override;
protected:
//
// graph
//
ggml_cgraph * graph_init() override;
public:
//
// graph build
//
llama_graph_input_ptr build_inp_pos_bucket(
ggml_context * ctx0,
int32_t n_tokens) const override;
llama_graph_input_attn_ptr build_attn_inp(
ggml_context * ctx0,
int32_t n_tokens,
bool causal,
bool swa) const override;
ggml_tensor * build_attn(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) const override;
protected:
llama_graph_result_ptr graph_build_kv_self_shift(
ggml_context * ctx0,
ggml_cgraph * gf) const;
// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
llama_graph_result_ptr graph_build_kv_self_defrag(
ggml_context * ctx0,
ggml_cgraph * gf) const;
//
// state save/load
//
size_t state_write_data(llama_io_write_i & io) override;
size_t state_read_data (llama_io_read_i & io) override;
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) override;
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id) override;
private:
//
// members
//
std::unique_ptr<llama_kv_cache_unified> kv_self;
};
// a recurrent transformer (e.g. RWKV, Mamba)
class llama_context_recurrent : public llama_context_base {
public:
llama_context_recurrent(
const llama_model & model,
llama_context_params params,
llama_graph_type gtype);
virtual ~llama_context_recurrent();
protected:
void reserve() override;
public:
llama_kv_cache * get_kv_self() override;
const llama_kv_cache * get_kv_self() const override;
void kv_self_update() override;
int encode(llama_batch & inp_batch) override;
int decode(llama_batch & inp_batch) override;
protected:
//
// graph
//
ggml_cgraph * graph_init() override;
public:
//
// graph build
//
llama_graph_input_ptr build_inp_s_copy(
ggml_context * ctx0) const override;
llama_graph_input_ptr build_inp_s_mask(
ggml_context * ctx0) const override;
ggml_tensor * build_copy_mask_state(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const override;
ggml_tensor * build_mamba_layer(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const override;
ggml_tensor * build_rwkv_token_shift_load(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const override;
ggml_tensor * build_rwkv_token_shift_store(
ggml_context * ctx0,
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const override;
ggml_tensor * build_rwkv6_time_mix(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const override;
protected:
//
// state save/load
//
size_t state_write_data(llama_io_write_i & io) override;
size_t state_read_data (llama_io_read_i & io) override;
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) override;
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id) override;
private:
//
// members
//
// TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?
std::unique_ptr<llama_kv_cache_recurrent> kv_self;
};
//
// enc-dec
//
// TODO: tmp - need something better to pass the data from the encoder to the decoder
struct llama_cross {
// the output embeddings from the encoder as a ggml tensor
// TODO: this needs more work to be correct, for now copy the embeddings data to host memory
// ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
ggml_tensor * t_embd = nullptr;
// embeddings data copied to host memory (tmp)
float * v_embd = nullptr;
// needed to construct the cross-attention mask in the decoder
std::vector<std::set<llama_seq_id>> seq_ids_enc;
};
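As a hedged illustration of the host-copy approach described in the TODO above, the encoder side could populate llama_cross roughly as follows; the helper name, the F32 assumption, and the staging buffer are illustrative and not part of the diff.

// illustrative sketch only - requires ggml-backend.h and <vector>
// assumes t_embd is the F32 output tensor produced by the encoder graph
static void llama_cross_store_example(llama_cross & cross, ggml_tensor * t_embd, std::vector<float> & host_buf) {
    cross.t_embd = t_embd;

    // temporary: copy the embeddings to host memory so the decoder can read them later
    host_buf.resize(ggml_nelements(t_embd));
    ggml_backend_tensor_get(t_embd, host_buf.data(), 0, ggml_nbytes(t_embd));

    cross.v_embd = host_buf.data();
}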
class llama_context_enc : public llama_context_base {
public:
using llama_context_base::llama_context_base;
int encode(llama_batch & inp_batch) override;
llama_cross * cross = nullptr; // TODO: hacky, rework
};
class llama_context_dec : public llama_context_kv_self {
public:
using llama_context_kv_self::llama_context_kv_self;
protected:
void reserve() override;
//
// graph
//
ggml_cgraph * graph_init() override;
llama_graph_input_ptr build_inp_cross_embd(
ggml_context * ctx0) const override;
llama_graph_input_attn_ptr build_attn_inp(
ggml_context * ctx0,
int32_t n_tokens,
bool causal,
bool swa) const override;
ggml_tensor * build_attn_cross(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) const override;
public:
llama_cross * cross = nullptr; // TODO: hacky, rework
};
class llama_context_enc_dec : public llama_context {
public:
llama_context_enc_dec(
const llama_model & model,
llama_context_params params);
~llama_context_enc_dec();
void init() override;
void synchronize() override;
const llama_model & get_model() const override;
// TODO: the default implementation of these getters calls the corresponding getter of the enc or dec context
// in the future, the public API in llama.h should allow getting references to the context that the user wants
// this will make it possible to address the desired context explicitly
// for example:
//
// // this can be an enc-dec context
// llama_context_t ctx = llama_init_from_model(...);
//
// ...
//
// llama_context_t ctx_enc = llama_get_ctx_enc(ctx);
// llama_set_embeddings(ctx_enc, true);
//
// llama_context_t ctx_dec = llama_get_ctx_dec(ctx);
// llama_set_causal_attn(ctx_dec, true);
//
uint32_t n_ctx() const override;
uint32_t n_ctx_per_seq() const override;
uint32_t n_batch() const override;
uint32_t n_ubatch() const override;
uint32_t n_seq_max() const override;
uint32_t n_threads() const override;
uint32_t n_threads_batch() const override;
llama_kv_cache * get_kv_self() override;
const llama_kv_cache * get_kv_self() const override;
void kv_self_update() override;
enum llama_pooling_type pooling_type() const override;
float * get_logits() override;
float * get_logits_ith(int32_t i) override;
float * get_embeddings() override;
float * get_embeddings_ith(int32_t i) override;
float * get_embeddings_seq(llama_seq_id seq_id) override;
void attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch) override;
void detach_threadpool() override;
void set_n_threads(int32_t n_threads, int32_t n_threads_batch) override;
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) override;
void set_embeddings (bool value) override;
void set_causal_attn(bool value) override;
void set_adapter_lora(
llama_adapter_lora * adapter,
float scale) override;
bool rm_adapter_lora(
llama_adapter_lora * adapter) override;
void clear_adapter_lora() override;
bool apply_adapter_cvec(
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end) override;
int encode(llama_batch & inp_batch) override;
int decode(llama_batch & inp_batch) override;
//
// perf
//
llama_perf_context_data perf_get_data() const override;
void perf_reset() override;
//
// state save/load
//
size_t state_get_size() override;
size_t state_get_data( uint8_t * dst, size_t size) override;
size_t state_set_data(const uint8_t * src, size_t size) override;
size_t state_seq_get_size(llama_seq_id seq_id) override;
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) override;
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) override;
bool state_load_file(
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) override;
bool state_save_file(
const char * filepath,
const llama_token * tokens,
size_t n_token_count) override;
size_t state_seq_load_file(
llama_seq_id seq_id,
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) override;
size_t state_seq_save_file(
llama_seq_id seq_id,
const char * filepath,
const llama_token * tokens,
size_t n_token_count) override;
private:
std::unique_ptr<llama_context_enc> ctx_enc;
std::unique_ptr<llama_context_dec> ctx_dec;
llama_cross cross;
};
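For orientation, the composed context presumably wires both sub-contexts to the shared llama_cross member; the helper below is a hedged sketch that only uses the cross pointers declared above, not the actual constructor code.

// illustrative sketch only
static void llama_enc_dec_wire_cross_example(llama_context_enc & enc, llama_context_dec & dec, llama_cross & cross) {
    // the encoder fills the cross state, the decoder reads it back when building its graph
    enc.cross = &cross;
    dec.cross = &cross;
}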
// For internal test use
// TODO: remove

src/llama-graph.cpp (new file, 193 lines)

@@ -0,0 +1,193 @@
#include "llama-graph.h"
#include "llama-impl.h"
ggml_tensor * llama_graph_input_attn_i::get_kq_mask() {
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr;
}
ggml_tensor * llama_graph_input_attn_i::get_kq_mask_swa() {
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr;
}
ggml_tensor * llama_graph_input_attn_i::get_kq_mask_cross() {
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr;
}
llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {}
llama_graph_input_ptr llama_graph_i::build_inp_cross_embd(
ggml_context * ctx0) const {
GGML_UNUSED(ctx0);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr;
}
ggml_tensor * llama_graph_i::build_attn(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) const {
GGML_UNUSED(inp);
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(q_cur);
GGML_UNUSED(k_cur);
GGML_UNUSED(v_cur);
GGML_UNUSED(kq_b);
GGML_UNUSED(kq_scale);
GGML_UNUSED(il);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr;
}
ggml_tensor * llama_graph_i::build_attn_cross(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) const {
GGML_UNUSED(inp);
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(q_cur);
GGML_UNUSED(k_cur);
GGML_UNUSED(v_cur);
GGML_UNUSED(kq_b);
GGML_UNUSED(kq_scale);
GGML_UNUSED(il);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr;
}
llama_graph_input_ptr llama_graph_i::build_inp_s_copy (
ggml_context * ctx0) const {
GGML_UNUSED(ctx0);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr; // NOLINT
}
llama_graph_input_ptr llama_graph_i::build_inp_s_mask(
ggml_context * ctx0) const {
GGML_UNUSED(ctx0);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr; // NOLINT
}
ggml_tensor * llama_graph_i::build_copy_mask_state(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const {
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(s);
GGML_UNUSED(state_copy);
GGML_UNUSED(state_mask);
GGML_UNUSED(n_state);
GGML_UNUSED(n_seqs);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr; // NOLINT
}
ggml_tensor * llama_graph_i::build_mamba_layer(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(cur);
GGML_UNUSED(state_copy);
GGML_UNUSED(state_mask);
GGML_UNUSED(ubatch);
GGML_UNUSED(il);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr; // NOLINT
}
ggml_tensor * llama_graph_i::build_rwkv_token_shift_load(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(state_copy);
GGML_UNUSED(state_mask);
GGML_UNUSED(ubatch);
GGML_UNUSED(il);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr; // NOLINT
}
ggml_tensor * llama_graph_i::build_rwkv_token_shift_store(
ggml_context * ctx0,
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const {
GGML_UNUSED(ctx0);
GGML_UNUSED(token_shift);
GGML_UNUSED(ubatch);
GGML_UNUSED(il);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr; // NOLINT
}
ggml_tensor * llama_graph_i::build_rwkv6_time_mix(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(cur);
GGML_UNUSED(x_prev);
GGML_UNUSED(state_copy);
GGML_UNUSED(state_mask);
GGML_UNUSED(ubatch);
GGML_UNUSED(il);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr; // NOLINT
}

src/llama-graph.h (new file, 278 lines)

@@ -0,0 +1,278 @@
#pragma once
#include <cstdint>
#include <vector>
#include <memory>
// note: do not add high-level objects here, such as llama_context, llama_kv_cache, etc.
// not sure about llama_batch/llama_sbatch yet
struct ggml_cgraph;
struct ggml_context;
struct ggml_tensor;
struct llama_ubatch;
// certain models (typically multi-modal) can produce different types of graphs
// the llama_context specifies which type of graph it needs through the llama_graph_i::type member
enum llama_graph_type {
LLAMA_GRAPH_TYPE_DEFAULT,
LLAMA_GRAPH_TYPE_ENCODER,
LLAMA_GRAPH_TYPE_DECODER,
};
//
// llama_graph_input
//
// denotes an input to the graph
// typically, the data of these objects is populated based on the contents of the current llama_ubatch:
//
// - llama_graph_input_pos
// - llama_graph_input_out_ids
// - etc.
//
// some inputs require context-specific data (e.g. KV cache) - such inputs are defined for the specific llama_context:
//
// - llama_graph_input_embd (can apply lora)
// - llama_graph_input_attn_kv_self (requires KV cache instance)
// - etc.
//
class llama_graph_input_i {
public:
virtual ~llama_graph_input_i() = default;
virtual void set_input(const llama_ubatch * ubatch) = 0;
// by default, we produce a single input tensor, but some implementations could produce more
ggml_tensor * cur = nullptr;
};
using llama_graph_input_ptr = std::shared_ptr<llama_graph_input_i>;
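For reference, a concrete input such as the llama_graph_input_pos mentioned above would, roughly, copy the per-token positions from the ubatch into its tensor; the class name and field accesses below are illustrative assumptions.

// illustrative sketch only - assumes the llama_ubatch definition is visible and that
// cur is an I32 [n_tokens] tensor created during graph build
class llama_graph_input_pos_example : public llama_graph_input_i {
public:
    void set_input(const llama_ubatch * ubatch) override {
        if (cur != nullptr && ubatch->pos != nullptr) {
            ggml_backend_tensor_set(cur, ubatch->pos, 0, ubatch->n_tokens*ggml_element_size(cur));
        }
    }
};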
class llama_graph_input_attn_i : public llama_graph_input_i {
public:
virtual ~llama_graph_input_attn_i() = default;
virtual ggml_tensor * get_kq_mask();
virtual ggml_tensor * get_kq_mask_swa();
virtual ggml_tensor * get_kq_mask_cross();
};
using llama_graph_input_attn_ptr = std::shared_ptr<llama_graph_input_attn_i>;
//
// llama_graph_result
//
// these objects deliver the result from the graph build process back to the llama_context
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
// specific data by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extract the relevant data, based on the compute parameters
class llama_graph_result_i {
public:
virtual ~llama_graph_result_i() = default;
virtual ggml_tensor * get_logits() = 0;
virtual ggml_tensor * get_embd() = 0;
virtual ggml_tensor * get_embd_pooled() = 0;
virtual void set_inputs(const llama_ubatch * ubatch) = 0;
};
using llama_graph_result_ptr = std::unique_ptr<llama_graph_result_i>;
class llama_graph_result : public llama_graph_result_i {
public:
virtual ~llama_graph_result() = default;
ggml_tensor * get_logits() override { return t_logits; }
ggml_tensor * get_embd() override { return t_embd; }
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
void set_inputs(const llama_ubatch * ubatch) override {
for (auto & input : inputs) {
input->set_input(ubatch);
}
}
void add_input(llama_graph_input_ptr input) {
inputs.emplace_back(std::move(input));
}
// important graph nodes
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
std::vector<llama_graph_input_ptr> inputs;
};
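To make the intended flow concrete: after building a graph, the context hands the result the current ubatch, schedules the computation, and then reads the important nodes. The sketch below is illustrative; the scheduler calls are assumptions rather than the exact code used by llama_context.

// illustrative sketch only
static void llama_graph_result_usage_example(
        ggml_backend_sched_t   sched,
        ggml_cgraph          * gf,
        llama_graph_result_i * res,
        const llama_ubatch   & ubatch) {
    res->set_inputs(&ubatch); // populate every registered input from the current ubatch

    ggml_backend_sched_graph_compute(sched, gf);

    // afterwards, the context extracts whatever it needs from the important graph nodes
    ggml_tensor * t_logits = res->get_logits();
    GGML_UNUSED(t_logits);
}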
//
// llama_graph
//
// this interface defines an API for building graphs by abstracting some high-level concepts such as attention, lora, etc.
// functionality that is trivial and does not rely on the llama_context should be directly implemented in llm_build_context
// other context-specific functionality should be declared here and implemented in the llama_context variations
//
// the main goal of this interface is to separate the llama_context specifics from the graph building logic
// this allows for cleaner model architecture definitions, while still being able to overload certain complex
// functionality in order to fit different use cases and/or explore new implementations and ideas
// note: keep all methods const
// TODO: can become more granular in the future
class llama_graph_i {
public:
llama_graph_i(llama_graph_type type);
virtual ~llama_graph_i() = default;
llama_graph_type get_type() const {
return type;
}
private:
llama_graph_type type;
public:
virtual int32_t get_n_outputs() const = 0;
//
// context-specific API
//
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
virtual void build_cb(
ggml_tensor * cur,
const char * name,
const llama_ubatch & ubatch,
int il) const = 0;
// apply control vector for layer il
virtual ggml_tensor * build_cvec(
ggml_context * ctx0,
ggml_tensor * cur,
int il) const = 0;
// do mat_mul, optionally applying lora
virtual ggml_tensor * build_lora_mm(
ggml_context * ctx0,
ggml_tensor * w,
ggml_tensor * cur) const = 0;
// do mat_mul_id, optionally applying lora
virtual ggml_tensor * build_lora_mm_id(
ggml_context * ctx0,
ggml_tensor * w, // struct ggml_tensor * as
ggml_tensor * cur, // struct ggml_tensor * b
ggml_tensor * ids) const = 0;
// rope factors based on the current context size
virtual ggml_tensor * build_rope_factors(int il) const = 0;
// input embeddings with optional lora
virtual llama_graph_input_ptr build_inp_embd(
ggml_context * ctx0,
ggml_tensor * tok_embd,
const llama_ubatch & ubatch) const = 0;
// enc-dec pos
virtual llama_graph_input_ptr build_inp_pos_bucket(
ggml_context * ctx0,
int32_t n_tokens) const = 0;
virtual llama_graph_input_ptr build_inp_cross_embd(
ggml_context * ctx0) const;
//
// attention API
//
virtual llama_graph_input_attn_ptr build_attn_inp(
ggml_context * ctx0,
int32_t n_tokens,
bool causal,
bool swa) const = 0;
virtual ggml_tensor * build_attn(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) const;
virtual ggml_tensor * build_attn_cross(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) const;
//
// recurrent API
//
virtual llama_graph_input_ptr build_inp_s_copy(
ggml_context * ctx0) const;
virtual llama_graph_input_ptr build_inp_s_mask(
ggml_context * ctx0) const;
virtual ggml_tensor * build_copy_mask_state(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const;
virtual ggml_tensor * build_mamba_layer(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const;
virtual ggml_tensor * build_rwkv_token_shift_load(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const;
virtual ggml_tensor * build_rwkv_token_shift_store(
ggml_context * ctx0,
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const;
virtual ggml_tensor * build_rwkv6_time_mix(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const;
};

src/llama-io.cpp (new file, 15 lines)

@@ -0,0 +1,15 @@
#include "llama-io.h"
void llama_io_write_i::write_string(const std::string & str) {
uint32_t str_size = str.size();
write(&str_size, sizeof(str_size));
write(str.data(), str_size);
}
void llama_io_read_i::read_string(std::string & str) {
uint32_t str_size;
read_to(&str_size, sizeof(str_size));
str.assign((const char *) read(str_size), str_size);
}

src/llama-io.h (new file, 35 lines)

@@ -0,0 +1,35 @@
#pragma once
#include <cstddef>
#include <cstdint>
#include <string>
struct ggml_tensor;
class llama_io_write_i {
public:
llama_io_write_i() = default;
virtual ~llama_io_write_i() = default;
virtual void write(const void * src, size_t size) = 0;
virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0;
// bytes written so far
virtual size_t n_bytes() = 0;
void write_string(const std::string & str);
};
class llama_io_read_i {
public:
llama_io_read_i() = default;
virtual ~llama_io_read_i() = default;
virtual const uint8_t * read(size_t size) = 0;
virtual void read_to(void * dst, size_t size) = 0;
// bytes read so far
virtual size_t n_bytes() = 0;
void read_string(std::string & str);
};
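For illustration, a minimal in-memory writer satisfying llama_io_write_i could look as follows; this is a sketch, not the implementation actually used for state save/load.

// illustrative sketch only - requires <vector> and ggml-backend.h
class llama_io_write_buffer_example : public llama_io_write_i {
public:
    void write(const void * src, size_t size) override {
        const uint8_t * p = (const uint8_t *) src;
        data.insert(data.end(), p, p + size);
    }

    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
        // stage the tensor data in host memory, then append it to the buffer
        std::vector<uint8_t> tmp(size);
        ggml_backend_tensor_get(tensor, tmp.data(), offset, size);
        write(tmp.data(), size);
    }

    size_t n_bytes() override { return data.size(); }

    std::vector<uint8_t> data;
};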

File diff suppressed because it is too large.

@@ -1,12 +1,29 @@
#pragma once
#include "llama.h"
#include "llama-io.h"
#include "llama-memory.h"
#include "ggml-cpp.h"
#include <set>
#include <vector>
struct llama_cparams;
struct llama_hparams;
struct llama_ubatch;
struct llama_kv_cache : public llama_memory_i {
using llama_memory_i::llama_memory_i;
virtual int32_t get_n_tokens() const = 0;
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too specific to the unified cache
virtual bool get_can_shift() const = 0;
bool get_can_edit() const override { return get_can_shift(); }
};
struct llama_kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
@@ -28,11 +45,86 @@ struct llama_kv_cell {
}
};
// a structure that holds information about the slot found by find_slot()
struct llama_kv_cache_slot_info {
std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
bool found = false; // the slot was found
explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
operator bool() const { return found; }
};
// ring-buffer of cached KV data
// TODO: pimpl
// TODO: add notion of max sequences
class llama_kv_cache_unified : public llama_kv_cache {
public:
llama_kv_cache_unified(const llama_hparams & hparams);
virtual ~llama_kv_cache_unified() = default;
// TODO: become constructor
bool init(
const llama_model & model, // TODO: do not reference the model
const llama_cparams & cparams,
ggml_type type_k,
ggml_type type_v,
uint32_t kv_size,
bool offload);
int32_t get_n_tokens() const override;
uint32_t get_used_cells() const override;
size_t total_size() const;
// TODO: better data structures to reduce the cost of this operation
llama_pos pos_max() const;
void clear() override;
void defrag() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_max(llama_seq_id seq_id) override;
bool get_can_shift() const override;
// find an empty slot of size "n_tokens" in the cache
// updates the cache head
// returns a structure holding information about the slot found
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
llama_kv_cache_slot_info find_slot(const llama_ubatch & batch);
// TODO: maybe not needed
uint32_t get_padding(const llama_cparams & cparams) const;
// find how many cells are currently in use
uint32_t cell_max() const;
size_t size_k_bytes() const;
size_t size_v_bytes() const;
// state save/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
// members
const llama_hparams & hparams;
bool has_shift = false;
bool do_defrag = false;
// TODO: remove this and implement llama_kv_cache_recurrent instead
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
bool v_trans = true; // the value tensor is transposed
bool can_shift = false;
@@ -46,125 +138,31 @@ struct llama_kv_cache {
// computed before each graph build
uint32_t n = 0;
std::vector<llama_kv_cell> cells;
std::vector<struct ggml_tensor *> k_l; // per layer
std::vector<struct ggml_tensor *> v_l;
private:
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
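As a usage sketch, a caller reserving cache space for a ubatch checks the returned slot info through its bool conversion; illustrative only, the real decode path additionally handles defragmentation and rollback.

// illustrative sketch only
static bool llama_kv_reserve_example(llama_kv_cache_unified & kv, const llama_ubatch & ubatch) {
    const llama_kv_cache_slot_info slot = kv.find_slot(ubatch);
    if (!slot) {
        return false; // no contiguous run of free cells large enough for the ubatch
    }

    // slot.boundaries holds the [begin, end) cell range that now backs the ubatch
    return true;
}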
// TODO: temporarily reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
class llama_kv_cache_recurrent : public llama_kv_cache_unified {
public:
using llama_kv_cache_unified::llama_kv_cache_unified;
};
// TODO: maybe not needed
uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams);
bool llama_kv_cache_init(
struct llama_kv_cache & cache,
const llama_model & model,
const llama_cparams & cparams,
ggml_type type_k,
ggml_type type_v,
uint32_t kv_size,
bool offload);
// find an empty slot of size "n_tokens" in the cache
// updates the cache head
// returns a structure holding information about the slot found
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
struct llama_kv_cache & cache,
const struct llama_ubatch & batch);
// find how many cells are currently in use
uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache);
void llama_kv_cache_clear(struct llama_kv_cache & cache);
bool llama_kv_cache_seq_rm(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_cp(
struct llama_kv_cache & cache,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_keep(
struct llama_kv_cache & cache,
llama_seq_id seq_id);
void llama_kv_cache_seq_add(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
void llama_kv_cache_seq_div(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
llama_pos llama_kv_cache_seq_pos_max(
struct llama_kv_cache & cache,
llama_seq_id seq_id);
void llama_kv_cache_defrag(struct llama_kv_cache & cache);
int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv);
int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv);
bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv);
//
// kv cache view
//
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);
void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
//
// kv cache restore
//
@@ -183,7 +181,9 @@ struct llama_kv_slot_restorer {
bool do_restore = false;
llama_kv_cache_unified & cache;
explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
old_state.head = cache.head;
old_state.n = cache.n;
}
@@ -200,19 +200,68 @@ struct llama_kv_slot_restorer {
// must be explicitly called to restore the kv_cache state
// and rollback changes from all llama_kv_cache_find_slot calls
void restore() {
if (do_restore) {
cache.head = old_state.head;
cache.n = old_state.n;
if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
cache.seq_rm(-1, -1, -1);
} else {
for (auto & slot : slot_boundaries) {
cache.seq_rm(-1, slot.first, slot.second);
}
}
}
}
};
// TODO: maybe become part of the public llama_kv_cache in the future
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);
void llama_kv_cache_clear(llama_kv_cache * kv);
bool llama_kv_cache_seq_rm(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_cp(
llama_kv_cache * kv,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);
void llama_kv_cache_seq_add(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
void llama_kv_cache_seq_div(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);
void llama_kv_cache_defrag(llama_kv_cache * kv);
bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
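Presumably these thin C-style helpers delegate to the virtual interface above; the bodies below are a hedged sketch (including the null checks), not necessarily the exact implementation.

// illustrative sketch only
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
    return kv ? kv->get_n_tokens() : 0;
}

void llama_kv_cache_seq_cp(llama_kv_cache * kv, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    if (kv) {
        kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    }
}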
//
// kv cache view
//
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);
void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache * kv);

src/llama-memory.cpp (new file, 1295 lines; diff suppressed because it is too large)

src/llama-memory.h (new file, 21 lines)

@@ -0,0 +1,21 @@
#pragma once
#include "llama.h"
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
class llama_memory_i {
public:
virtual void clear() = 0;
virtual void defrag() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
virtual void seq_keep(llama_seq_id seq_id) = 0;
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
virtual bool get_can_edit() const = 0;
};
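As a small usage sketch, context shifting can be expressed purely in terms of this interface; the helper name is illustrative and p1 < 0 is assumed to mean "until the end of the sequence".

// illustrative sketch only
static void llama_memory_shift_example(llama_memory_i & mem, llama_seq_id seq_id, llama_pos n_keep, llama_pos n_discard) {
    // drop the positions [n_keep, n_keep + n_discard) ...
    mem.seq_rm (seq_id, n_keep, n_keep + n_discard);

    // ... and slide the remaining positions down so the sequence stays contiguous
    mem.seq_add(seq_id, n_keep + n_discard, -1, -n_discard);
}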

File diff suppressed because it is too large.

@@ -3,6 +3,7 @@
#include "llama.h"
#include "llama-arch.h"
#include "llama-hparams.h"
#include "llama-graph.h"
#include "llama-vocab.h"
#include <memory>
@@ -10,6 +11,8 @@
#include <unordered_map>
#include <vector>
struct llama_cparams;
struct llama_ubatch;
struct llama_model_loader;
// available models
@@ -347,7 +350,7 @@ struct llama_model {
std::string desc() const;
size_t size() const;
size_t max_nodes() const;
size_t n_tensors() const;
size_t n_devices() const;
// total number of parameters in the model
@@ -362,6 +365,13 @@ struct llama_model {
const struct ggml_tensor * get_tensor(const char * name) const;
llama_graph_result_ptr build_graph(
ggml_context * ctx,
ggml_cgraph * gf,
llama_graph_i * lgf,
const llama_cparams & cparams,
const llama_ubatch & ubatch) const;
private:
struct impl;
std::unique_ptr<impl> pimpl;

File diff suppressed because it is too large.