Compare commits

...

8 Commits

Author SHA1 Message Date
Georgi Gerganov
dfceb012ee llama : add "virtual sequences"
ggml-ci
2025-07-02 20:26:55 +03:00
Georgi Gerganov
30b4d4e1b3 ggml : add TODOs for adding GGML_OP_SET_ROWS support in the backends
ggml-ci
2025-07-02 15:49:19 +03:00
Georgi Gerganov
5495ea96ba kv-cache : add comments
ggml-ci
2025-07-02 15:49:19 +03:00
Georgi Gerganov
f3da97e61b kv-cache : bounds-check when accessing slot_info indices 2025-07-02 15:49:18 +03:00
Georgi Gerganov
a70293bc25 kv-cache : improve find_slot impl 2025-07-02 15:49:18 +03:00
Georgi Gerganov
2ac5be3a58 cont : remove redundant ifs
ggml-ci
2025-07-02 15:49:18 +03:00
Georgi Gerganov
ac8f3474c8 graph : separate k and v indices
ggml-ci
2025-07-02 15:49:18 +03:00
Georgi Gerganov
cd811b7a9d kv-cache : use ggml_set_rows
ggml-ci
2025-07-02 15:49:16 +03:00
22 changed files with 1029 additions and 329 deletions

View File

@@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
llama_batch batch = llama_batch_init(n_ctx, 0, 1);
llama_batch batch = llama_batch_init(n_ctx*n_clients, 0, 1);
int32_t n_total_prompt = 0;
int32_t n_total_gen = 0;
@@ -289,6 +289,7 @@ int main(int argc, char ** argv) {
// all sequences have ended - clear the entire KV cache
for (int i = 1; i <= n_clients; ++i) {
llama_memory_seq_rm(mem, i, -1, -1);
// but keep the system prompt
llama_memory_seq_cp(mem, 0, i, -1, -1);
}

View File

@@ -2086,6 +2086,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
return false;
}
} break;
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CPY: {
ggml_tensor *src = op->src[0];
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||

View File

@@ -2222,6 +2222,12 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
default:
return false;
}
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CPY:
case GGML_OP_DUP:
case GGML_OP_CONT:

View File

@@ -4285,6 +4285,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return false;
}
}
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CPY:
{
ggml_type src0_type = op->src[0]->type;

View File

@@ -10333,6 +10333,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false;
}
} break;
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CONT:
case GGML_OP_CPY:
case GGML_OP_DUP:

View File

@@ -4696,7 +4696,6 @@ struct ggml_tensor * ggml_flash_attn_ext(
if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(mask->ne[2] == q->ne[3]);
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));

View File

@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
// note: tracking the other way around is not necessary for now
//seq_cpl[s0][s1] = true;
has_cpl = true;
}
}
}
@@ -405,6 +407,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
return n_outputs;
}
uint32_t llama_batch_allocr::get_n_used() const {
return n_used;
}
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
return out_ids;
}
@@ -420,6 +426,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
void llama_batch_allocr::split_reset() {
out_ids.clear();
n_used = 0;
used.clear();
used.resize(get_n_tokens(), false);
@@ -444,6 +452,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
idxs.push_back(cur_idx);
used[cur_idx] = true;
++n_used;
++cur_idx;
@@ -459,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
return ubatch_add(idxs, idxs.size(), false);
}
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
if (sequential && has_cpl) {
LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
return {};
}
std::vector<seq_set_t> cur_seq_set;
llama_seq_id last_seq_id = -1;
// determine the non-overlapping sequence sets participating in this ubatch
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (used[i]) {
@@ -478,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
}
}
// accept only increasing sequence ids
if (sequential) {
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
}
if (add) {
cur_seq_set.push_back(seq_set[i]);
last_seq_id = batch.seq_id[i][0];
if (cur_seq_set.size() > n_ubatch) {
break;
}
@@ -529,6 +553,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
idxs_per_seq[s].push_back(idx);
used[idx] = true;
++n_used;
++cur_idx[s];
}
@@ -570,6 +595,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
idxs.push_back(cur_idx);
used[cur_idx] = true;
++n_used;
if (idxs.size() >= n_ubatch) {
break;

View File

@@ -54,6 +54,7 @@ public:
uint32_t get_n_tokens() const;
uint32_t get_n_outputs() const;
uint32_t get_n_used() const;
// the array of output indices in the order they were encountered during the ubatch splitting
std::vector<int32_t> & get_out_ids();
@@ -69,7 +70,8 @@ public:
llama_ubatch split_simple(uint32_t n_ubatch);
// make ubatches of equal-length sequences sets
llama_ubatch split_equal(uint32_t n_ubatch);
// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
// sequence-set-wise split - each ubatch contains a single sequence-set
llama_ubatch split_seq(uint32_t n_ubatch);
@@ -112,6 +114,9 @@ private:
using pos_set_t = std::set<llama_pos>;
using seq_cpl_t = std::vector<bool>;
// helper flag to quickly determine if there are any coupled sequences in the batch
bool has_cpl;
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
@@ -125,6 +130,8 @@ private:
// batch indices of the output
std::vector<int32_t> out_ids;
uint32_t n_used;
// used[i] indicates if token i has already been used in a previous ubatch
std::vector<bool> used;

View File

@@ -33,6 +33,9 @@ llama_context::llama_context(
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
}
const char * LLAMA_HT = getenv("LLAMA_HT");
cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -1308,7 +1311,8 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
this->n_outputs = n_outputs;
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
//llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens, 1);
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);

View File

@@ -11,8 +11,9 @@ struct llama_cparams {
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
int n_threads; // number of threads to use for generation
int n_threads_batch; // number of threads to use for batch processing
uint32_t n_seq_virt;
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing
float rope_freq_base;
float rope_freq_scale;

View File

@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
}
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
mctx->set_input_k_idxs(self_k_idxs, ubatch);
mctx->set_input_v_idxs(self_v_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
if (self_kq_mask_swa) {
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
}
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
const int64_t n_rs = mctx->get_recr()->get_n_rs();
@@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
}
}
void llm_graph_input_one::set_input(const llama_ubatch *) {
void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
GGML_ASSERT(one && ggml_nelements(one) == 1);
float f_one = 1.0f;
ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@@ -995,10 +1000,13 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
const auto n_kv = inp->mctx->get_attn()->get_n_kv();
const auto n_kv = inp->mctx->get_attn()->get_n_kv();
const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1025,6 +1033,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
float kq_scale) const {
const bool v_trans = v->nb[1] > v->nb[2];
const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
@@ -1073,7 +1085,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
#endif
}
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
} else {
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
@@ -1118,7 +1130,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
if (!cparams.offload_kqv) {
// all nodes between the KV store and the attention output are run on the CPU
@@ -1196,10 +1208,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
const auto n_kv = mctx_cur->get_n_kv();
const auto n_kv = mctx_cur->get_n_kv();
const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1230,8 +1245,11 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache
{
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
const auto & k_idxs = inp->get_k_idxs();
const auto & v_idxs = inp->get_v_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
}
const auto & kq_mask = inp->get_kq_mask();
@@ -1290,11 +1308,15 @@ ggml_tensor * llm_graph_context::build_attn(
// optionally store to KV cache
if (k_cur) {
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
}
if (v_cur) {
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
}
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1398,8 +1420,11 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache
{
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
const auto & k_idxs = inp->get_k_idxs();
const auto & v_idxs = inp->get_v_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
}
const auto & kq_mask = inp->get_kq_mask();
@@ -1431,11 +1456,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
{
const auto n_kv = mctx_cur->get_base()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1446,8 +1475,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;

View File

@@ -249,10 +249,16 @@ public:
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
const llama_hparams & hparams;
const llama_cparams & cparams;
@@ -274,13 +280,23 @@ public:
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
const llama_hparams & hparams;
const llama_cparams & cparams;
@@ -319,8 +335,14 @@ public:
ggml_tensor * s_copy; // I32 [kv_size]
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
@@ -336,7 +358,7 @@ public:
llm_graph_input_one() {}
virtual ~llm_graph_input_one() = default;
void set_input(const llama_ubatch *) override;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * one = nullptr; // F32
};

View File

@@ -20,14 +20,15 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
bool swa_full,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_seq_virt,
uint32_t n_ubatch,
uint32_t n_pad) : hparams(model.hparams) {
uint32_t n_pad) : hparams(model.hparams), n_seq_virt(n_seq_virt) {
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
const uint32_t size_base = kv_size;
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
if (swa_full) {
@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
kv_base = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_base), type_k, type_v,
v_trans, offload, size_base, n_seq_max, n_pad,
v_trans, offload, size_base, n_seq_max, n_seq_virt, n_pad,
0, LLAMA_SWA_TYPE_NONE);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_swa), type_k, type_v,
v_trans, offload, size_swa, n_seq_max, n_pad,
v_trans, offload, size_swa, n_seq_max, n_seq_virt, n_pad,
hparams.n_swa, hparams.swa_type);
}
@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
// first try simple split
do {
if (n_seq_virt > 1) {
// requires equal splits, so we skip the simple split
break;
}
balloc.split_reset();
std::vector<llama_ubatch> ubatches;
@@ -113,20 +119,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
ubatches.push_back(std::move(ubatch)); // NOLINT
}
auto heads_base = kv_base->prepare(ubatches);
if (heads_base.empty()) {
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
auto heads_swa = kv_swa->prepare(ubatches);
if (heads_swa.empty()) {
auto sinfos_base = kv_base->prepare(ubatches);
if (sinfos_base.empty()) {
break;
}
assert(heads_base.size() == heads_swa.size());
auto sinfos_swa = kv_swa->prepare(ubatches);
if (sinfos_swa.empty()) {
break;
}
assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false);
// if it fails, try equal split
@@ -135,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
std::vector<llama_ubatch> ubatches;
while (true) {
auto ubatch = balloc.split_equal(n_ubatch);
auto ubatch = balloc.split_equal(n_ubatch, n_seq_virt > 1);
if (ubatch.n_tokens == 0) {
break;
@@ -144,20 +155,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
ubatches.push_back(std::move(ubatch)); // NOLINT
}
auto heads_base = kv_base->prepare(ubatches);
if (heads_base.empty()) {
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
auto heads_swa = kv_swa->prepare(ubatches);
if (heads_swa.empty()) {
auto sinfos_base = kv_base->prepare(ubatches);
if (sinfos_base.empty()) {
break;
}
assert(heads_base.size() == heads_swa.size());
auto sinfos_swa = kv_swa->prepare(ubatches);
if (sinfos_swa.empty()) {
break;
}
assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false);
// TODO: if we fail again, we should attempt different splitting strategies
@@ -220,13 +236,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

View File

@@ -22,6 +22,7 @@ public:
bool swa_full,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_seq_virt,
uint32_t n_ubatch,
uint32_t n_pad);
@@ -68,12 +69,16 @@ public:
private:
const llama_hparams & hparams;
const uint32_t n_seq_virt = 1;
std::unique_ptr<llama_kv_cache_unified> kv_base;
std::unique_ptr<llama_kv_cache_unified> kv_swa;
};
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
// used for errors
llama_kv_cache_unified_iswa_context(llama_memory_status status);
@@ -90,8 +95,8 @@ public:
// used to create a batch processing context from a batch
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_iswa_context();

File diff suppressed because it is too large Load Diff

View File

@@ -24,8 +24,6 @@ public:
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
using ubatch_heads = std::vector<uint32_t>;
struct defrag_info {
bool empty() const {
return ids.empty();
@@ -37,6 +35,50 @@ public:
std::vector<uint32_t> ids;
};
// for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
// KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
struct slot_info {
// data for ggml_set_rows
using idx_vec_t = std::vector<uint32_t>;
llama_seq_id s0;
llama_seq_id s1;
std::vector<llama_seq_id> seq_id_virt;
std::vector<idx_vec_t> idxs;
uint32_t head() const {
GGML_ASSERT(idxs.size() == 1);
return idxs.at(0).at(0);
}
void resize(size_t n) {
seq_id_virt.resize(n);
idxs.resize(n);
}
size_t size() const {
GGML_ASSERT(idxs.size() == seq_id_virt.size());
return idxs.at(0).size();
}
size_t n_seq_virt() const {
return seq_id_virt.size();
}
bool empty() const {
return idxs.empty();
}
void clear() {
idxs.clear();
}
};
using slot_info_vec_t = std::vector<slot_info>;
llama_kv_cache_unified(
const llama_model & model,
layer_filter_cb && filter,
@@ -46,6 +88,7 @@ public:
bool offload,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_seq_virt,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type);
@@ -98,36 +141,44 @@ public:
uint32_t get_n_kv() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
// store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
//
// preparation API
//
// find places for the provided ubatches in the cache, returns the head locations
// find places for the provided ubatches in the cache, returns the slot infos
// return empty vector on failure
ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
// return the cell position where we can insert the ubatch
// return -1 on failure to find a contiguous slot of kv cells
int32_t find_slot(const llama_ubatch & ubatch) const;
// find a slot of kv cells that can hold the ubatch
// if cont == true, then the slot must be continuous
// return empty slot_info on failure
slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
// emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
// emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
//
// set_input API
// input API
//
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_k_shift(ggml_tensor * dst) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_k_shift (ggml_tensor * dst) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
private:
@@ -141,15 +192,15 @@ private:
ggml_tensor * k;
ggml_tensor * v;
std::vector<ggml_tensor *> k_seq;
std::vector<ggml_tensor *> v_seq;
};
bool v_trans = true; // the value tensor is transposed
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
uint32_t head = 0;
const uint32_t n_seq_max = 1;
const uint32_t n_seq_max = 1;
const uint32_t n_seq_virt = 1;
// required padding
const uint32_t n_pad = 1;
@@ -157,14 +208,26 @@ private:
// SWA
const uint32_t n_swa = 0;
// env: LLAMA_KV_CACHE_DEBUG
int debug = 0;
// env: LLAMA_SET_ROWS (temporary)
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
int supports_set_rows = false;
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
llama_kv_cells_unified cells;
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads;
std::vector<llama_kv_cells_unified> v_cells;
// maps from a sequence id to a virtual sequence id
std::vector<uint32_t> seq_virt_idx;
std::vector<kv_layer> layers;
@@ -211,8 +274,8 @@ private:
class llama_kv_cache_unified_context : public llama_memory_context_i {
public:
// some shorthands
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
using defrag_info = llama_kv_cache_unified::defrag_info;
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using defrag_info = llama_kv_cache_unified::defrag_info;
// used for errors
llama_kv_cache_unified_context(llama_memory_status status);
@@ -231,7 +294,7 @@ public:
// used to create a batch procesing context from a batch
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
ubatch_heads heads,
slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_context();
@@ -257,11 +320,16 @@ public:
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
// store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
void set_input_k_shift(ggml_tensor * dst) const;
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_k_shift (ggml_tensor * dst) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -283,10 +351,10 @@ private:
// batch processing context
//
// the index of the next ubatch to process
size_t i_next = 0;
// the index of the cur ubatch to process
size_t i_cur = 0;
ubatch_heads heads;
slot_info_vec_t sinfos;
std::vector<llama_ubatch> ubatches;
@@ -297,7 +365,4 @@ private:
// a heuristic, to avoid attending the full cache if it is not yet utilized
// as the cache gets filled, the benefit from this heuristic disappears
int32_t n_kv;
// the beginning of the current slot in which the ubatch will be inserted
int32_t head;
};

View File

@@ -105,10 +105,30 @@ public:
res.resize(n);
for (uint32_t j = 0; j < n; ++j) {
res.pos[j] = pos[i + j];
res.seq[j] = seq[i + j];
const auto idx = i + j;
assert(shift[i + j] == 0);
res.pos[j] = pos[idx];
res.seq[j] = seq[idx];
assert(shift[idx] == 0);
}
return res;
}
// copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
llama_kv_cells_unified res;
res.resize(idxs.size());
for (uint32_t j = 0; j < idxs.size(); ++j) {
const auto idx = idxs[j];
res.pos[j] = pos[idx];
res.seq[j] = seq[idx];
assert(shift[idx] == 0);
}
return res;
@@ -119,26 +139,58 @@ public:
assert(i + other.pos.size() <= pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) {
if (pos[i + j] == -1 && other.pos[j] != -1) {
const auto idx = i + j;
if (pos[idx] == -1 && other.pos[j] != -1) {
used.insert(i + j);
}
if (pos[i + j] != -1 && other.pos[j] == -1) {
if (pos[idx] != -1 && other.pos[j] == -1) {
used.erase(i + j);
}
if (pos[i + j] != -1) {
if (pos[idx] != -1) {
seq_pos_rm(i + j);
}
pos[i + j] = other.pos[j];
seq[i + j] = other.seq[j];
pos[idx] = other.pos[j];
seq[idx] = other.seq[j];
if (pos[i + j] != -1) {
if (pos[idx] != -1) {
seq_pos_add(i + j);
}
assert(shift[i + j] == 0);
assert(shift[idx] == 0);
}
}
// set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
assert(idxs.size() == other.pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) {
const auto idx = idxs[j];
if (pos[idx] == -1 && other.pos[j] != -1) {
used.insert(idx);
}
if (pos[idx] != -1 && other.pos[j] == -1) {
used.erase(idx);
}
if (pos[idx] != -1) {
seq_pos_rm(idx);
}
pos[idx] = other.pos[j];
seq[idx] = other.seq[j];
if (pos[idx] != -1) {
seq_pos_add(idx);
}
assert(shift[idx] == 0);
}
}

View File

@@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
offload,
kv_size,
n_seq_max,
1,
n_pad,
n_swa,
swa_type
@@ -70,7 +71,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch);
ubatch = balloc.split_equal(n_ubatch, false);
}
if (ubatch.n_tokens == 0) {
@@ -80,6 +81,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
// prepare the recurrent batches first
if (!mem_recr->prepare(ubatches)) {
// TODO: will the recurrent cache be in an undefined context at this point?
@@ -195,11 +201,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid * mem,
std::vector<uint32_t> heads_attn,
slot_info_vec_t sinfos_attn,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}

View File

@@ -92,6 +92,8 @@ private:
class llama_memory_hybrid_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
// init failure
explicit llama_memory_hybrid_context(llama_memory_status status);
@@ -107,7 +109,7 @@ public:
// init success
llama_memory_hybrid_context(
llama_memory_hybrid * mem,
std::vector<uint32_t> heads_attn,
slot_info_vec_t sinfos_attn,
std::vector<llama_ubatch> ubatches);
~llama_memory_hybrid_context() = default;

View File

@@ -374,10 +374,11 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch);
ubatch = balloc.split_equal(n_ubatch, false);
}
if (ubatch.n_tokens == 0) {
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}

View File

@@ -14499,6 +14499,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
params.swa_full,
cparams.n_ctx,
cparams.n_seq_max,
cparams.n_seq_virt,
cparams.n_ubatch,
padding);
} else {
@@ -14513,6 +14514,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.offload_kqv,
cparams.n_ctx,
cparams.n_seq_max,
cparams.n_seq_virt,
padding,
hparams.n_swa,
hparams.swa_type);

View File

@@ -61,7 +61,7 @@ int main(int argc, char ** argv) {
const int32_t n_kv_max = llama_n_ctx(ctx);
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
llama_batch batch = llama_batch_init(n_kv_max*8, 0, 1); // TODO: tmp!!!
// decode in batches of ctx_params.n_batch tokens
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
@@ -119,9 +119,9 @@ int main(int argc, char ** argv) {
const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
if (n_ctx_req > n_kv_max) {
continue;
}
//if (n_ctx_req > n_kv_max) {
// continue;
//}
common_batch_clear(batch);