server : simplify prompt state transition branches

Georgi Gerganov
2026-01-09 17:46:03 +02:00
parent cc5cafecf4
commit 4a2751258a

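The diff below splits the old combined check on SLOT_STATE_PROCESSING_PROMPT || SLOT_STATE_STARTED into two sequential branches, one per state transition. The following is a minimal standalone sketch of that control-flow shape, not the server's actual code: the slot_state values mirror the names used in server.cpp, but process_slot and its body are illustrative stand-ins.

#include <cstdio>

// illustrative stand-in for the server's slot state enum (names mirror server.cpp)
enum slot_state {
    SLOT_STATE_IDLE,
    SLOT_STATE_STARTED,
    SLOT_STATE_PROCESSING_PROMPT,
    SLOT_STATE_DONE_PROMPT,
};

// hypothetical helper showing the new branch layout: one branch per transition,
// executed in order, instead of a single branch guarding both states
static void process_slot(slot_state & state) {
    // SLOT_STATE_STARTED -> SLOT_STATE_PROCESSING_PROMPT
    if (state == SLOT_STATE_STARTED) {
        // one-time prompt setup would happen here
        state = SLOT_STATE_PROCESSING_PROMPT;
    }

    // SLOT_STATE_PROCESSING_PROMPT -> SLOT_STATE_DONE_PROMPT
    if (state == SLOT_STATE_PROCESSING_PROMPT) {
        // per-iteration prompt batching would happen here
        state = SLOT_STATE_DONE_PROMPT;
    }
}

int main() {
    slot_state state = SLOT_STATE_STARTED;
    process_slot(state);
    std::printf("final state: %d\n", (int) state); // prints 3 (SLOT_STATE_DONE_PROMPT)
    return 0;
}

Because the branches run back to back, a slot that enters the loop as SLOT_STATE_STARTED is promoted and then handled by the prompt-processing branch in the same iteration, which matches the behavior of the old combined condition.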

@@ -2104,8 +2104,9 @@ private:
                continue;
            }

            // this slot still has a prompt to be processed
-           if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
+           // SLOT_STATE_STARTED -> SLOT_STATE_PROCESSING_PROMPT
+           // TODO: maybe move branch to outside of this loop in the future
+           if (slot.state == SLOT_STATE_STARTED) {
                // wait for all children to be launched
                if (slot.is_parent()) {
                    int n_launched = 0;
@@ -2123,289 +2124,291 @@ private:
                const auto & input_tokens = slot.task->tokens;

-               // TODO: maybe move branch to outside of this loop in the future
-               if (slot.state == SLOT_STATE_STARTED) {
                slot.t_start_process_prompt = ggml_time_us();
                slot.t_start_generation = 0;

                slot.state = SLOT_STATE_PROCESSING_PROMPT;

                SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
                        slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());

                // print prompt tokens (for debugging)
                /*if (1) {
                    // first 16 tokens (avoid flooding logs)
                    for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
                        SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
                    }
                } else {
                    // all
                    for (int i = 0; i < (int) input_tokens.size(); i++) {
                        SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
                    }
                }*/

                // keep track how many tokens we can reuse from the previous state
                int n_past = 0;

                // empty prompt passed -> release the slot and send empty response
                if (input_tokens.empty()) {
                    SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");

                    slot.print_timings();
                    send_final_response(slot);
                    slot.release();

                    continue;
                }

                // TODO: support memory-less logits computation
                if (slot.need_logits() && !llama_get_memory(ctx)) {
                    send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
                    slot.release();
                    continue;
                }

                if (!slot.can_split()) {
                    if (slot.task->n_tokens() > n_ubatch) {
                        send_error(slot,
                                   string_format(
                                       "input (%d tokens) is too large to process. increase the physical batch "
                                       "size (current batch size: %d)",
                                       slot.task->n_tokens(), n_ubatch),
                                   ERROR_TYPE_SERVER);
                        slot.release();
                        continue;
                    }

                    if (slot.task->n_tokens() > slot.n_ctx) {
                        send_error(
                            slot,
                            string_format(
                                "input (%d tokens) is larger than the max context size (%d tokens). skipping",
                                slot.task->n_tokens(), slot.n_ctx),
                            ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                        slot.release();
                        continue;
                    }
                } else {
                    if (slot.task->n_tokens() >= slot.n_ctx) {
                        send_error(slot,
                                   string_format("request (%d tokens) exceeds the available context size (%d "
                                                 "tokens), try increasing it",
                                                 slot.task->n_tokens(), slot.n_ctx),
                                   ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                        slot.release();
                        continue;
                    }
                }

                if (slot.task->params.cache_prompt) {
                    // reuse any previously computed tokens that are common with the new prompt
                    n_past = slot.prompt.tokens.get_common_prefix(input_tokens);

                    // if there is an alora invoked, don't cache after the invocation start
                    if (slot.alora_invocation_start > 0) {
                        SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start);
                        n_past = std::min(n_past, slot.alora_invocation_start - 1);
                    }

                    const auto n_cache_reuse = slot.task->params.n_cache_reuse;

                    const bool can_cache_reuse =
                        llama_memory_can_shift(llama_get_memory(ctx)) &&
                        !slot.prompt.tokens.has_mtmd;

                    if (!can_cache_reuse && n_cache_reuse > 0) {
                        SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse);
                    }

                    // reuse chunks from the cached prompt by shifting their KV cache in the new position
                    if (can_cache_reuse && n_cache_reuse > 0) {
                        GGML_ASSERT(!slot.prompt.tokens.has_mtmd);

                        size_t head_c = n_past; // cache
                        size_t head_p = n_past; // current prompt

                        if (mctx) {
                            // we should never reach this
                            GGML_ABORT("not supported by multimodal");
                        }

                        SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past);

                        while (head_c < slot.prompt.tokens.size() &&
                               head_p < input_tokens.size()) {
                            size_t n_match = 0;
                            while (head_c + n_match < slot.prompt.tokens.size() &&
                                   head_p + n_match < input_tokens.size() &&
                                   slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
                                n_match++;
                            }

                            if (n_match >= (size_t) n_cache_reuse) {
                                SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);

                                //for (size_t i = head_p; i < head_p + n_match; i++) {
                                //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
                                //}

                                const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;

                                llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c);
                                llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift);

                                for (size_t i = 0; i < n_match; i++) {
                                    slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
                                    n_past++;
                                }

                                head_c += n_match;
                                head_p += n_match;
                            } else {
                                head_c += 1;
                            }
                        }

                        SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past);
                    }
                } else {
                    // if we don't cache the prompt, we have to remove all previous tokens
                    n_past = 0;
                }

                // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
                const auto n_swa = std::max(1, llama_model_n_swa(model));

                // the largest pos_min required for a checkpoint to be useful
                const auto pos_min_thold = std::max(0, n_past - n_swa);

                // note: disallow with mtmd contexts for now
                // https://github.com/ggml-org/llama.cpp/issues/17043
                if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
                    const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                    if (pos_min == -1) {
                        SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
                        GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                    }

                    // when the prompt prefix does not match, print the tokens around the mismatch
                    // this is useful for debugging prompt caching
                    if (slots_debug) {
                        const int np0 = std::max<int>(n_past - 4, 0);
                        const int np1 = std::min<int>(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));

                        std::stringstream ss0;
                        std::stringstream ss1;

                        std::stringstream st0;
                        std::stringstream st1;

                        ss0 << "old: ... ";
                        ss1 << "new: ... ";

                        for (int i = np0; i < np1; i++) {
                            if (i == n_past) {
                                ss0 << " | ";
                                ss1 << " | ";
                            }

                            {
                                const auto token = slot.prompt.tokens[i];
                                const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
                                ss0 << piece;
                                st0 << std::setw(8) << token;
                            }

                            {
                                const auto token = slot.task->tokens[i];
                                const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
                                ss1 << piece;
                                st1 << std::setw(8) << token;
                            }
                        }

                        SLT_WRN(slot, "%s\n", ss0.str().c_str());
                        SLT_WRN(slot, "%s\n", ss1.str().c_str());

                        SLT_WRN(slot, "%s\n", st0.str().c_str());
                        SLT_WRN(slot, "%s\n", st1.str().c_str());
                    }

                    if (pos_min > pos_min_thold) {
                        // TODO: support can be added in the future when corresponding vision models get released
                        GGML_ASSERT(!slot.prompt.tokens.has_mtmd);

                        SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);

                        // search for a context checkpoint
                        const auto it = std::find_if(
                            slot.prompt.checkpoints.rbegin(),
                            slot.prompt.checkpoints.rend(),
                            [&](const auto & cur) {
                                // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
                                return cur.pos_min < pos_min_thold;
                            }
                        );

                        bool do_reset = it == slot.prompt.checkpoints.rend();

                        if (!do_reset) {
                            // restore the context checkpoint
                            const size_t checkpoint_size = it->data.size();
                            const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

                            if (n != checkpoint_size) {
                                SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
                                do_reset = true;
                                //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                            } else {
                                n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
                                SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
                            }
                        }

                        if (do_reset) {
                            SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                    "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
                            n_past = 0;
                        }
                    }
                }

                {
                    // erase any checkpoints with pos_min > pos_min_thold
                    for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
                        const auto & cur = *it;
                        if (cur.pos_min > pos_min_thold) {
                            SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
                            it = slot.prompt.checkpoints.erase(it);
                        } else {
                            ++it;
                        }
                    }
                }

                // [TAG_PROMPT_LOGITS]
                if (n_past == slot.task->n_tokens() && n_past > 0) {
                    SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
                    n_past--;
                    SLT_WRN(slot, "n_past was set to %d\n", n_past);
                }

                slot.n_prompt_tokens_cache = n_past;
                slot.n_prompt_tokens_processed = 0;

                slot.prompt.tokens.keep_first(n_past);

                // send initial 0% progress update if needed
                // this is to signal the client that the request has started processing
                if (slot.task->params.stream && slot.task->params.return_progress) {
                    send_partial_response(slot, {}, true);
                }
            }

+           // SLOT_STATE_PROCESSING_PROMPT -> SLOT_STATE_DONE_PROMPT
+           if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
+               const auto & input_tokens = slot.task->tokens;

                if (!slot.can_split()) {
                    // cannot fit the prompt in the current batch - will try next iter
                    if (batch.n_tokens + slot.task->n_tokens() > n_batch) {