mirror of https://github.com/ggerganov/llama.cpp.git
synced 2026-02-19 14:13:22 +02:00

Compare commits: b8061 ... gg/server- (11 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 4a2751258a | |
| | cc5cafecf4 | |
| | aef22e7afc | |
| | 9ceb268ee1 | |
| | a4854f0349 | |
| | f2d988db55 | |
| | 91fd50be1b | |
| | 439c3b5021 | |
| | 59dda88aae | |
| | d7c27d4964 | |
| | a9d7bcb7fc | |
@@ -79,6 +79,8 @@ struct server_slot {

common_speculative * spec = nullptr;

// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
std::unique_ptr<const server_task> task;
std::unique_ptr<const server_task> task_prev; // used for debugging

@@ -153,7 +155,7 @@ struct server_slot {

common_sampler_ptr smpl;

llama_token sampled; // in speculative mode, this is the last accepted token
llama_tokens drafted;

// stats

@@ -201,12 +203,46 @@ struct server_slot {
alora_invocation_start = -1;
}

// remove cached prompt + tokens
void clear(bool allow_processing) {
if (!allow_processing) {
GGML_ASSERT(!is_processing());
}

SLT_INF(*this, "clearing slot with %zu tokens\n", prompt.tokens.size());

llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
prompt.tokens.clear();
}

void init_sampler() const {
const int64_t t_start = ggml_time_us();

common_sampler_reset(smpl.get());

int n_text = 0;

for (int i = 0; i < (int) prompt.tokens.size(); i++) {
const llama_token id = prompt.tokens[i];

if (id != LLAMA_TOKEN_NULL) {
common_sampler_accept(smpl.get(), id, false);
n_text++;
}
}

SLT_INF(*this, "init sampler, took %0.2f ms, tokens: text = %d, total = %d\n",
(ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
}

// TODO: move to server_task
bool need_embd() const {
GGML_ASSERT(task);

return server_task_type_need_embd(task->type);
}

// TODO: move to server_task
bool need_logits() const {
GGML_ASSERT(task);

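The new init_sampler() helper resets the sampler and re-accepts every cached prompt token, so history-dependent samplers (repetition and frequency penalties and similar) see the same token history as if the prompt had just been decoded; this matters in particular for child slots that receive a cloned prompt. The standalone sketch below illustrates why such a replay is needed with a toy penalty sampler; all names and types here are invented for illustration and are not llama.cpp API.

#include <unordered_map>
#include <vector>

// toy sampler whose output depends on the token history it has accepted so far
struct toy_penalty_sampler {
    std::unordered_map<int, int> counts; // token -> occurrences seen so far

    void accept(int token) { counts[token]++; }

    // a logit is penalized proportionally to how often the token already appeared
    float adjust(int token, float logit, float alpha = 0.5f) const {
        auto it = counts.find(token);
        return it == counts.end() ? logit : logit - alpha * (float) it->second;
    }
};

// replay a cached prompt into a fresh sampler, analogous in spirit to init_sampler()
void replay(toy_penalty_sampler & smpl, const std::vector<int> & prompt_tokens) {
    for (int id : prompt_tokens) {
        if (id >= 0) { // skip placeholder tokens, cf. LLAMA_TOKEN_NULL in the diff
            smpl.accept(id);
        }
    }
}
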
@@ -258,10 +294,13 @@ struct server_slot {
SLT_WRN(*this, "%s", "slot is not processing\n");
return;
}

generated_token_probs.push_back(token);
}

int get_n_draft_max() const {
GGML_ASSERT(task);

if (!can_speculate()) {
return 0;
}

@@ -287,12 +326,14 @@ struct server_slot {
}

// note: a slot can also be either a parent or a child
// TODO: move to server_task
bool is_parent() const {
return is_processing() && task->n_children > 0;
return task->n_children > 0;
}

// TODO: move to server_task
bool is_child() const {
return is_processing() && task->id_parent >= 0;
return task->id_parent >= 0;
}

void release() {

@@ -301,10 +342,16 @@ struct server_slot {

SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);

t_last_used = ggml_time_us();
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;

state = SLOT_STATE_IDLE;

// do not keep context of the child slots - the parent's context is enough
if (is_child()) {
clear(false);
}

task_prev = std::move(task);
task.reset();

@@ -425,14 +472,22 @@ struct server_slot {
}

void copy_state_to(server_slot & other) const {
llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT);

llama_memory_seq_rm(llama_get_memory(ctx), other.id, -1, -1);
llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, -1, -1);

other.n_decoded = n_decoded;
other.n_remaining = n_remaining;
other.i_batch = i_batch;

other.t_start_process_prompt = t_start_process_prompt;
other.t_prompt_processing = t_prompt_processing;
other.n_prompt_tokens_cache = n_prompt_tokens_cache;
other.n_prompt_tokens_processed = n_prompt_tokens_processed;

other.prompt = prompt.clone();
other.init_sampler();
}
};

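copy_state_to() asserts that the parent has finished its prompt, duplicates the parent's KV-cache sequence into the child's sequence, copies the prompt-processing counters and timings, clones the prompt tokens and rebuilds the child's sampler. For reference, a minimal sketch (not part of the diff) of the underlying sequence-copy pattern, assuming a valid llama_context * ctx and the same llama_memory_* calls that appear in the diff; error handling omitted.

#include "llama.h"

// duplicate the cached state of sequence `src` into sequence `dst`
// (the same calls copy_state_to() builds on)
static void clone_sequence(llama_context * ctx, llama_seq_id src, llama_seq_id dst) {
    llama_memory_t mem = llama_get_memory(ctx);

    // drop whatever the destination sequence currently holds
    llama_memory_seq_rm(mem, dst, -1, -1);

    // copy the full position range from src to dst
    llama_memory_seq_cp(mem, src, dst, -1, -1);
}

Copying the memory alone is not enough: the sampler history still has to match the cloned prompt, which is why copy_state_to() ends with other.init_sampler().
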
@@ -745,6 +800,7 @@ private:
}

slots.clear();

for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;

@@ -993,7 +1049,7 @@ private:
ret->prompt_save(*prompt_cache);

if (!ret->prompt_load(*prompt_cache, task.tokens)) {
clear_slot(*ret);
ret->clear(false);
}

prompt_cache->update();

@@ -1005,17 +1061,6 @@ private:
return ret;
}

void clear_slot(server_slot & slot, bool allow_processing = false) const {
if (!allow_processing) {
GGML_ASSERT(!slot.is_processing());
}

SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());

llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
slot.prompt.tokens.clear();
}

// return true if at least one slot has been cleared
// TODO: improve logic
// - smarter decision which slot to clear (LRU or longest prompt?)

@@ -1036,7 +1081,7 @@ private:
if (slot.prompt.n_tokens() > 0) {
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());

clear_slot(slot);
slot.clear(false);

res = true;

@@ -1182,7 +1227,7 @@ private:
? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
: SLOT_STATE_STARTED;

SLT_INF(slot, "%s", "processing task\n");
SLT_INF(slot, "processing task, is_child = %d\n", slot.is_child());

return true;
}

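The hunk above is where a child slot enters SLOT_STATE_WAIT_OTHER instead of SLOT_STATE_STARTED. Pieced together from the states that appear in this diff, the slot lifecycle looks roughly like this (an illustrative summary, not a verbatim excerpt of the code):

// SLOT_STATE_IDLE
//   -> SLOT_STATE_STARTED            (task assigned to the slot; parent or standalone task)
//   -> SLOT_STATE_WAIT_OTHER         (child task: wait for the parent's prompt)
// SLOT_STATE_STARTED
//   -> SLOT_STATE_PROCESSING_PROMPT  (prompt batched for decoding)
//   -> SLOT_STATE_DONE_PROMPT        (prompt fully decoded)
// SLOT_STATE_WAIT_OTHER
//   -> SLOT_STATE_DONE_PROMPT        (parent calls copy_state_to() on the child)
// SLOT_STATE_DONE_PROMPT
//   -> generation, then release() sets the slot back to SLOT_STATE_IDLE
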
@@ -1819,7 +1864,7 @@ private:
// Erase token cache
const size_t n_erased = slot->prompt.tokens.size();

clear_slot(*slot);
slot->clear(false);

auto res = std::make_unique<server_task_result_slot_erase>();
res->id = task.id;

@@ -2053,293 +2098,317 @@ private:
continue;
}

// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
// check if this is a child slot
if (slot.state == SLOT_STATE_WAIT_OTHER) {
SLT_DBG(slot, "%s", "waiting for parent slot to complete\n");
continue;
}

// SLOT_STATE_STARTED -> SLOT_STATE_PROCESSING_PROMPT
// TODO: maybe move branch to outside of this loop in the future
if (slot.state == SLOT_STATE_STARTED) {
// wait for all children to be launched
if (slot.is_parent()) {
int n_launched = 0;
for (auto & other : slots) {
if (other.is_processing() && other.is_child() && other.task->id_parent == slot.task->id) {
++n_launched;
}
}

if (n_launched < slot.task->n_children) {
SLT_DBG(slot, "waiting for children to be launched, n_children = %d, n_launched = %d\n", slot.task->n_children, n_launched);
continue;
}
}

const auto & input_tokens = slot.task->tokens;

slot.t_start_process_prompt = ggml_time_us();
slot.t_start_generation = 0;

slot.state = SLOT_STATE_PROCESSING_PROMPT;

SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());

// print prompt tokens (for debugging)
/*if (1) {
// first 16 tokens (avoid flooding logs)
for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
}
} else {
// all
for (int i = 0; i < (int) input_tokens.size(); i++) {
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
}
}*/

// keep track how many tokens we can reuse from the previous state
int n_past = 0;

// empty prompt passed -> release the slot and send empty response
if (input_tokens.empty()) {
SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");

slot.print_timings();
send_final_response(slot);
slot.release();

continue;
}

// TODO: support memory-less logits computation
if (slot.need_logits() && !llama_get_memory(ctx)) {
send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
slot.release();
continue;
}

if (!slot.can_split()) {
if (slot.task->n_tokens() > n_ubatch) {
send_error(slot,
string_format(
"input (%d tokens) is too large to process. increase the physical batch "
"size (current batch size: %d)",
slot.task->n_tokens(), n_ubatch),
ERROR_TYPE_SERVER);
slot.release();
continue;
}

if (slot.task->n_tokens() > slot.n_ctx) {
send_error(
slot,
string_format(
"input (%d tokens) is larger than the max context size (%d tokens). skipping",
slot.task->n_tokens(), slot.n_ctx),
ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release();
continue;
}
} else {
if (slot.task->n_tokens() >= slot.n_ctx) {
send_error(slot,
string_format("request (%d tokens) exceeds the available context size (%d "
"tokens), try increasing it",
slot.task->n_tokens(), slot.n_ctx),
ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release();
continue;
}

if (slot.task->params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt
n_past = slot.prompt.tokens.get_common_prefix(input_tokens);

// if there is an alora invoked, don't cache after the invocation start
if (slot.alora_invocation_start > 0) {
SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start);
n_past = std::min(n_past, slot.alora_invocation_start - 1);
}

const auto n_cache_reuse = slot.task->params.n_cache_reuse;

const bool can_cache_reuse =
llama_memory_can_shift(llama_get_memory(ctx)) &&
!slot.prompt.tokens.has_mtmd;

if (!can_cache_reuse && n_cache_reuse > 0) {
SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse);
}

// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (can_cache_reuse && n_cache_reuse > 0) {
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);

size_t head_c = n_past; // cache
size_t head_p = n_past; // current prompt

if (mctx) {
// we should never reach this
GGML_ABORT("not supported by multimodal");
}

SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past);

while (head_c < slot.prompt.tokens.size() &&
head_p < input_tokens.size()) {

size_t n_match = 0;
while (head_c + n_match < slot.prompt.tokens.size() &&
head_p + n_match < input_tokens.size() &&
slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
n_match++;
}

if (n_match >= (size_t) n_cache_reuse) {
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
//for (size_t i = head_p; i < head_p + n_match; i++) {
//    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
//}

const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;

llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c);
llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift);

for (size_t i = 0; i < n_match; i++) {
slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
n_past++;
}

head_c += n_match;
head_p += n_match;
} else {
head_c += 1;
}
}

SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past);
}
} else {
// if we don't cache the prompt, we have to remove all previous tokens
n_past = 0;
}

// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
const auto n_swa = std::max(1, llama_model_n_swa(model));

// the largest pos_min required for a checkpoint to be useful
const auto pos_min_thold = std::max(0, n_past - n_swa);

// note: disallow with mtmd contexts for now
// https://github.com/ggml-org/llama.cpp/issues/17043
if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
if (pos_min == -1) {
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
}

// when the prompt prefix does not match, print the tokens around the mismatch
// this is useful for debugging prompt caching
if (slots_debug) {
const int np0 = std::max<int>(n_past - 4, 0);
const int np1 = std::min<int>(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));

std::stringstream ss0;
std::stringstream ss1;

std::stringstream st0;
std::stringstream st1;

ss0 << "old: ... ";
ss1 << "new: ... ";

for (int i = np0; i < np1; i++) {
if (i == n_past) {
ss0 << " | ";
ss1 << " | ";
}

{
const auto token = slot.prompt.tokens[i];
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
ss0 << piece;
st0 << std::setw(8) << token;
}

{
const auto token = slot.task->tokens[i];
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
ss1 << piece;
st1 << std::setw(8) << token;
}
}

SLT_WRN(slot, "%s\n", ss0.str().c_str());
SLT_WRN(slot, "%s\n", ss1.str().c_str());

SLT_WRN(slot, "%s\n", st0.str().c_str());
SLT_WRN(slot, "%s\n", st1.str().c_str());
}

if (pos_min > pos_min_thold) {
// TODO: support can be added in the future when corresponding vision models get released
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);

SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);

// search for a context checkpoint
const auto it = std::find_if(
slot.prompt.checkpoints.rbegin(),
slot.prompt.checkpoints.rend(),
[&](const auto & cur) {
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
return cur.pos_min < pos_min_thold;
}
);

bool do_reset = it == slot.prompt.checkpoints.rend();

if (!do_reset) {
// restore the context checkpoint
const size_t checkpoint_size = it->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

if (n != checkpoint_size) {
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
}
}

if (do_reset) {
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
n_past = 0;
}
}
}

{
// erase any checkpoints with pos_min > pos_min_thold
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
const auto & cur = *it;
if (cur.pos_min > pos_min_thold) {
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
it = slot.prompt.checkpoints.erase(it);
} else {
++it;
}
}
}
}

// [TAG_PROMPT_LOGITS]
if (n_past == slot.task->n_tokens() && n_past > 0) {
SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
n_past--;
SLT_WRN(slot, "n_past was set to %d\n", n_past);
}

slot.n_prompt_tokens_cache = n_past;
slot.n_prompt_tokens_processed = 0;

slot.prompt.tokens.keep_first(n_past);

// send initial 0% progress update if needed
// this is to signal the client that the request has started processing
if (slot.task->params.stream && slot.task->params.return_progress) {
send_partial_response(slot, {}, true);
}
}

// SLOT_STATE_PROCESSING_PROMPT -> SLOT_STATE_DONE_PROMPT
if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
const auto & input_tokens = slot.task->tokens;

if (!slot.can_split()) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.task->n_tokens() > n_batch) {

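The cache-reuse branch in the hunk above first takes the common prefix via get_common_prefix(), then scans for further chunks of at least n_cache_reuse identical tokens and shifts their KV-cache entries to the new positions instead of recomputing them. Below is a standalone sketch of that matching and bookkeeping on plain token vectors; it is illustrative only, and where it prints a shift the real code performs it with llama_memory_seq_rm/llama_memory_seq_add as shown above.

#include <cstdint>
#include <cstdio>
#include <vector>

// find chunks shared between the cached tokens and the new prompt, starting at a
// common prefix of length n_past, and report the position shift for each reused chunk
static int reuse_chunks(const std::vector<int> & cached, const std::vector<int> & prompt,
                        int n_past, size_t n_reuse_min) {
    size_t head_c = (size_t) n_past; // cursor in the cached tokens
    size_t head_p = (size_t) n_past; // cursor in the new prompt

    while (head_c < cached.size() && head_p < prompt.size()) {
        size_t n_match = 0;
        while (head_c + n_match < cached.size() &&
               head_p + n_match < prompt.size() &&
               cached[head_c + n_match] == prompt[head_p + n_match]) {
            n_match++;
        }

        if (n_match >= n_reuse_min) {
            // a real implementation would shift the KV entries of
            // [head_c, head_c + n_match) by `shift` here
            const int64_t shift = (int64_t) head_p - (int64_t) head_c;
            printf("reuse %zu tokens, shift %lld\n", n_match, (long long) shift);

            n_past += (int) n_match;
            head_c += n_match;
            head_p += n_match;
        } else {
            head_c += 1;
        }
    }

    return n_past; // number of prompt tokens that do not need to be re-decoded
}
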
@@ -2355,7 +2424,7 @@ private:
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);

clear_slot(slot, /*allow_processing=*/true);
slot.clear(true);

// there is no common part left
slot.n_prompt_tokens_cache = 0;

@@ -2455,16 +2524,6 @@ private:

GGML_ASSERT(batch.n_tokens > 0);

common_sampler_reset(slot.smpl.get());

// Process all prompt tokens through sampler system
for (int i = 0; i < slot.task->n_tokens(); ++i) {
llama_token id = input_tokens[i];
if (id != LLAMA_TOKEN_NULL) {
common_sampler_accept(slot.smpl.get(), id, false);
}
}

// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;

@@ -2473,6 +2532,8 @@ private:

SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);

slot.init_sampler();

const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);

@@ -2519,11 +2580,6 @@ private:
}
}

if (batch.n_tokens == 0) {
SRV_WRN("%s", "no tokens to decode\n");
return;
}

SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);

if (slot_batched) {

@@ -2540,6 +2596,10 @@ private:
llama_set_embeddings(ctx, slot_batched->need_embd());
}

if (batch.n_tokens == 0) {
SRV_WRN("%s", "no tokens to decode\n");
}

int32_t i_next = 0;

// process the created batch of tokens

@@ -2591,7 +2651,7 @@ private:

// note: it's complicated to keep track of how much of the current batch has been
// processed before the error occurred, so we simply clear the entire context
clear_slot(slot);
slot.clear(false);
}
}

@@ -2615,27 +2675,34 @@ private:
// on successful decode, restore the original batch size
n_batch = llama_n_batch(ctx);

// handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
for (auto & slot : slots) {
// may need to copy state to other slots
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
std::vector<server_slot *> child_slots;
SLT_INF(slot, "parent task prompt done, n_children = %d\n", slot.task->n_children);

std::vector<server_slot *> children;
for (auto & other : slots) {
if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
child_slots.push_back(&other);
children.push_back(&other);
}
}

// we can only proceed if all child slots are having the correct tasks
if (child_slots.size() == slot.task->n_children) {
if (slot.task->n_children == (int) children.size()) {
// copy state to the child slots
for (auto & child : child_slots) {
SLT_INF(slot, "copying state to child %d\n", child->id);
for (auto & child : children) {
SLT_INF(slot, " - copying state to child %d\n", child->id);

GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);

slot.copy_state_to(*child);
child->state = SLOT_STATE_DONE_PROMPT;
}
}
}
}

for (auto & slot : slots) {
// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.task->params.stream && slot.task->params.return_progress) {

@@ -2720,7 +2787,7 @@ private:
continue;
}

size_t n_draft = slot.drafted.size();
const size_t n_draft = slot.drafted.size();

// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);

@@ -2923,9 +2990,11 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
task.params.oaicompat_cmpl_id = completion_id;
task.params.oaicompat_model = meta->model_name;

// prepare child tasks
if (task.params.n_cmpl > 1) {
task.n_children = task.params.n_cmpl - 1;
for (size_t j = 0; j < task.n_children; j++) {
for (int j = 0; j < task.n_children; j++) {
server_task child = task.create_child(task.id, rd.get_new_id());

// use different sampling seed for each child

@@ -2938,7 +3007,8 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
}
}

tasks.push_back(std::move(task));
// note: the parent task always launches first
tasks.insert(tasks.begin(), std::move(task));
}

rd.post_tasks(std::move(tasks));

@@ -121,8 +121,8 @@ struct server_task {
int id_slot = -1;

// used by parallel sampling (multiple completions from same prompt)
size_t n_children = 0; // number of tasks reusing this prompt
int id_parent = -1;
int n_children = 0; // number of tasks reusing this prompt
int id_parent = -1;

// used by SERVER_TASK_TYPE_INFERENCE
task_params params;

@@ -173,11 +173,13 @@ struct server_task {

server_task create_child(int id_parent, int id_child) const {
server_task copy;

copy.id = id_child;
copy.id_parent = id_parent;
copy.params = params;
copy.type = type;
copy.tokens = tokens.clone();

return copy;
}

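Taken together with the handler changes above: a request with n_cmpl > 1 keeps the original task as the parent, creates n_cmpl - 1 children via create_child(), gives each child its own sampling seed, and queues the parent first so the children can wait on its prompt. A condensed sketch of that expansion follows; it is illustrative only, and toy_task, expand_task and the seed arithmetic are placeholders rather than the server's actual types and helpers.

#include <vector>

struct toy_task {
    int id         = -1;
    int id_parent  = -1;
    int n_children = 0;
    int seed       = 0;

    toy_task create_child(int parent_id, int child_id) const {
        toy_task child = *this;   // the real code also clones params and tokens
        child.id         = child_id;
        child.id_parent  = parent_id;
        child.n_children = 0;
        return child;
    }
};

// expand one request with n_cmpl completions into 1 parent + (n_cmpl - 1) children
static std::vector<toy_task> expand_task(toy_task task, int n_cmpl, int & next_id) {
    std::vector<toy_task> tasks;

    task.n_children = n_cmpl - 1;
    for (int j = 0; j < task.n_children; j++) {
        toy_task child = task.create_child(task.id, next_id++);
        child.seed     = task.seed + 1 + j; // distinct sampling seed per child (illustrative)
        tasks.push_back(child);
    }

    // the parent always launches first, so the children can wait on its prompt
    tasks.insert(tasks.begin(), task);
    return tasks;
}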