mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-16 16:27:32 +03:00
Compare commits
8 Commits
b8047
...
gg/llama-h
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dfceb012ee | ||
|
|
30b4d4e1b3 | ||
|
|
5495ea96ba | ||
|
|
f3da97e61b | ||
|
|
a70293bc25 | ||
|
|
2ac5be3a58 | ||
|
|
ac8f3474c8 | ||
|
|
cd811b7a9d |
@@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
|
||||
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
|
||||
llama_batch batch = llama_batch_init(n_ctx, 0, 1);
|
||||
llama_batch batch = llama_batch_init(n_ctx*n_clients, 0, 1);
|
||||
|
||||
int32_t n_total_prompt = 0;
|
||||
int32_t n_total_gen = 0;
|
||||
@@ -289,6 +289,7 @@ int main(int argc, char ** argv) {
|
||||
// all sequences have ended - clear the entire KV cache
|
||||
for (int i = 1; i <= n_clients; ++i) {
|
||||
llama_memory_seq_rm(mem, i, -1, -1);
|
||||
|
||||
// but keep the system prompt
|
||||
llama_memory_seq_cp(mem, 0, i, -1, -1);
|
||||
}
|
||||
|
||||
@@ -2086,6 +2086,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
return false;
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
// TODO: add support
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
|
||||
return false;
|
||||
} break;
|
||||
case GGML_OP_CPY: {
|
||||
ggml_tensor *src = op->src[0];
|
||||
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
|
||||
|
||||
@@ -2222,6 +2222,12 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
// TODO: add support
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
|
||||
return false;
|
||||
} break;
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CONT:
|
||||
|
||||
@@ -4285,6 +4285,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return false;
|
||||
}
|
||||
}
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
// TODO: add support
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
|
||||
return false;
|
||||
} break;
|
||||
case GGML_OP_CPY:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
|
||||
@@ -10333,6 +10333,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
return false;
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
// TODO: add support
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
|
||||
return false;
|
||||
} break;
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_DUP:
|
||||
|
||||
@@ -4696,7 +4696,6 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
||||
|
||||
if (mask) {
|
||||
GGML_ASSERT(ggml_is_contiguous(mask));
|
||||
GGML_ASSERT(mask->ne[2] == q->ne[3]);
|
||||
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
|
||||
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
|
||||
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
|
||||
|
||||
@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
|
||||
|
||||
// note: tracking the other way around is not necessary for now
|
||||
//seq_cpl[s0][s1] = true;
|
||||
|
||||
has_cpl = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -405,6 +407,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
|
||||
return n_outputs;
|
||||
}
|
||||
|
||||
uint32_t llama_batch_allocr::get_n_used() const {
|
||||
return n_used;
|
||||
}
|
||||
|
||||
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
|
||||
return out_ids;
|
||||
}
|
||||
@@ -420,6 +426,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
|
||||
void llama_batch_allocr::split_reset() {
|
||||
out_ids.clear();
|
||||
|
||||
n_used = 0;
|
||||
|
||||
used.clear();
|
||||
used.resize(get_n_tokens(), false);
|
||||
|
||||
@@ -444,6 +452,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
|
||||
idxs.push_back(cur_idx);
|
||||
|
||||
used[cur_idx] = true;
|
||||
++n_used;
|
||||
|
||||
++cur_idx;
|
||||
|
||||
@@ -459,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
|
||||
return ubatch_add(idxs, idxs.size(), false);
|
||||
}
|
||||
|
||||
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
||||
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
|
||||
if (sequential && has_cpl) {
|
||||
LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<seq_set_t> cur_seq_set;
|
||||
|
||||
llama_seq_id last_seq_id = -1;
|
||||
|
||||
// determine the non-overlapping sequence sets participating in this ubatch
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
if (used[i]) {
|
||||
@@ -478,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
// accept only increasing sequence ids
|
||||
if (sequential) {
|
||||
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
|
||||
}
|
||||
|
||||
if (add) {
|
||||
cur_seq_set.push_back(seq_set[i]);
|
||||
|
||||
last_seq_id = batch.seq_id[i][0];
|
||||
|
||||
if (cur_seq_set.size() > n_ubatch) {
|
||||
break;
|
||||
}
|
||||
@@ -529,6 +553,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
||||
idxs_per_seq[s].push_back(idx);
|
||||
|
||||
used[idx] = true;
|
||||
++n_used;
|
||||
|
||||
++cur_idx[s];
|
||||
}
|
||||
@@ -570,6 +595,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
|
||||
idxs.push_back(cur_idx);
|
||||
|
||||
used[cur_idx] = true;
|
||||
++n_used;
|
||||
|
||||
if (idxs.size() >= n_ubatch) {
|
||||
break;
|
||||
|
||||
@@ -54,6 +54,7 @@ public:
|
||||
|
||||
uint32_t get_n_tokens() const;
|
||||
uint32_t get_n_outputs() const;
|
||||
uint32_t get_n_used() const;
|
||||
|
||||
// the array of output indices in the order they were encountered during the ubatch splitting
|
||||
std::vector<int32_t> & get_out_ids();
|
||||
@@ -69,7 +70,8 @@ public:
|
||||
llama_ubatch split_simple(uint32_t n_ubatch);
|
||||
|
||||
// make ubatches of equal-length sequences sets
|
||||
llama_ubatch split_equal(uint32_t n_ubatch);
|
||||
// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
|
||||
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
|
||||
|
||||
// sequence-set-wise split - each ubatch contains a single sequence-set
|
||||
llama_ubatch split_seq(uint32_t n_ubatch);
|
||||
@@ -112,6 +114,9 @@ private:
|
||||
using pos_set_t = std::set<llama_pos>;
|
||||
using seq_cpl_t = std::vector<bool>;
|
||||
|
||||
// helper flag to quickly determine if there are any coupled sequences in the batch
|
||||
bool has_cpl;
|
||||
|
||||
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
|
||||
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
|
||||
|
||||
@@ -125,6 +130,8 @@ private:
|
||||
// batch indices of the output
|
||||
std::vector<int32_t> out_ids;
|
||||
|
||||
uint32_t n_used;
|
||||
|
||||
// used[i] indicates if token i has already been used in a previous ubatch
|
||||
std::vector<bool> used;
|
||||
|
||||
|
||||
@@ -33,6 +33,9 @@ llama_context::llama_context(
|
||||
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
|
||||
}
|
||||
|
||||
const char * LLAMA_HT = getenv("LLAMA_HT");
|
||||
cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1;
|
||||
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||
@@ -1308,7 +1311,8 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
||||
this->n_outputs = n_outputs;
|
||||
|
||||
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
|
||||
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
|
||||
//llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
|
||||
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens, 1);
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
|
||||
|
||||
@@ -11,8 +11,9 @@ struct llama_cparams {
|
||||
uint32_t n_batch;
|
||||
uint32_t n_ubatch;
|
||||
uint32_t n_seq_max;
|
||||
int n_threads; // number of threads to use for generation
|
||||
int n_threads_batch; // number of threads to use for batch processing
|
||||
uint32_t n_seq_virt;
|
||||
int32_t n_threads; // number of threads to use for generation
|
||||
int32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
|
||||
@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_kq_mask) {
|
||||
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
mctx->set_input_k_idxs(self_k_idxs, ubatch);
|
||||
mctx->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
|
||||
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_kq_mask) {
|
||||
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
|
||||
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
|
||||
if (self_kq_mask_swa) {
|
||||
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
|
||||
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
|
||||
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
|
||||
|
||||
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
||||
@@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
|
||||
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_kq_mask) {
|
||||
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
|
||||
mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
|
||||
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
|
||||
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
||||
|
||||
@@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_one::set_input(const llama_ubatch *) {
|
||||
void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_UNUSED(ubatch);
|
||||
GGML_ASSERT(one && ggml_nelements(one) == 1);
|
||||
float f_one = 1.0f;
|
||||
ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
|
||||
@@ -995,10 +1000,13 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
|
||||
|
||||
const auto n_kv = inp->mctx->get_attn()->get_n_kv();
|
||||
const auto n_kv = inp->mctx->get_attn()->get_n_kv();
|
||||
const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
@@ -1025,6 +1033,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
float kq_scale) const {
|
||||
const bool v_trans = v->nb[1] > v->nb[2];
|
||||
|
||||
const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
|
||||
|
||||
q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
|
||||
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
||||
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
||||
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
||||
@@ -1073,7 +1085,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
#endif
|
||||
}
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
|
||||
} else {
|
||||
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
|
||||
@@ -1118,7 +1130,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
|
||||
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
|
||||
|
||||
if (!cparams.offload_kqv) {
|
||||
// all nodes between the KV store and the attention output are run on the CPU
|
||||
@@ -1196,10 +1208,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
@@ -1230,8 +1245,11 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
||||
const auto & k_idxs = inp->get_k_idxs();
|
||||
const auto & v_idxs = inp->get_v_idxs();
|
||||
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
|
||||
}
|
||||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
@@ -1290,11 +1308,15 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
|
||||
// optionally store to KV cache
|
||||
if (k_cur) {
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
||||
const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
|
||||
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
||||
}
|
||||
|
||||
if (v_cur) {
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
||||
const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
|
||||
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
|
||||
}
|
||||
|
||||
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
||||
@@ -1398,8 +1420,11 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
||||
const auto & k_idxs = inp->get_k_idxs();
|
||||
const auto & v_idxs = inp->get_v_idxs();
|
||||
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
|
||||
}
|
||||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
@@ -1431,11 +1456,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
|
||||
|
||||
const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
|
||||
|
||||
{
|
||||
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
@@ -1446,8 +1475,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
||||
|
||||
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
||||
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
|
||||
@@ -249,10 +249,16 @@ public:
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
|
||||
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
@@ -274,13 +280,23 @@ public:
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
|
||||
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
|
||||
ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
|
||||
ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
|
||||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
@@ -319,8 +335,14 @@ public:
|
||||
|
||||
ggml_tensor * s_copy; // I32 [kv_size]
|
||||
|
||||
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
|
||||
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||
|
||||
@@ -336,7 +358,7 @@ public:
|
||||
llm_graph_input_one() {}
|
||||
virtual ~llm_graph_input_one() = default;
|
||||
|
||||
void set_input(const llama_ubatch *) override;
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * one = nullptr; // F32
|
||||
};
|
||||
|
||||
@@ -20,14 +20,15 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
||||
bool swa_full,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_seq_virt,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad) : hparams(model.hparams) {
|
||||
uint32_t n_pad) : hparams(model.hparams), n_seq_virt(n_seq_virt) {
|
||||
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
||||
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
||||
|
||||
const uint32_t size_base = kv_size;
|
||||
|
||||
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
|
||||
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));
|
||||
|
||||
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
|
||||
if (swa_full) {
|
||||
@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
||||
|
||||
kv_base = std::make_unique<llama_kv_cache_unified>(
|
||||
model, std::move(filter_base), type_k, type_v,
|
||||
v_trans, offload, size_base, n_seq_max, n_pad,
|
||||
v_trans, offload, size_base, n_seq_max, n_seq_virt, n_pad,
|
||||
0, LLAMA_SWA_TYPE_NONE);
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
||||
|
||||
kv_swa = std::make_unique<llama_kv_cache_unified>(
|
||||
model, std::move(filter_swa), type_k, type_v,
|
||||
v_trans, offload, size_swa, n_seq_max, n_pad,
|
||||
v_trans, offload, size_swa, n_seq_max, n_seq_virt, n_pad,
|
||||
hparams.n_swa, hparams.swa_type);
|
||||
}
|
||||
|
||||
@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
||||
|
||||
// first try simple split
|
||||
do {
|
||||
if (n_seq_virt > 1) {
|
||||
// requires equal splits, so we skip the simple split
|
||||
break;
|
||||
}
|
||||
|
||||
balloc.split_reset();
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
@@ -113,20 +119,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
auto heads_base = kv_base->prepare(ubatches);
|
||||
if (heads_base.empty()) {
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
auto heads_swa = kv_swa->prepare(ubatches);
|
||||
if (heads_swa.empty()) {
|
||||
auto sinfos_base = kv_base->prepare(ubatches);
|
||||
if (sinfos_base.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
assert(heads_base.size() == heads_swa.size());
|
||||
auto sinfos_swa = kv_swa->prepare(ubatches);
|
||||
if (sinfos_swa.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
assert(sinfos_base.size() == sinfos_swa.size());
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(
|
||||
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
// if it fails, try equal split
|
||||
@@ -135,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (true) {
|
||||
auto ubatch = balloc.split_equal(n_ubatch);
|
||||
auto ubatch = balloc.split_equal(n_ubatch, n_seq_virt > 1);
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
@@ -144,20 +155,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
auto heads_base = kv_base->prepare(ubatches);
|
||||
if (heads_base.empty()) {
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
auto heads_swa = kv_swa->prepare(ubatches);
|
||||
if (heads_swa.empty()) {
|
||||
auto sinfos_base = kv_base->prepare(ubatches);
|
||||
if (sinfos_base.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
assert(heads_base.size() == heads_swa.size());
|
||||
auto sinfos_swa = kv_swa->prepare(ubatches);
|
||||
if (sinfos_swa.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
assert(sinfos_base.size() == sinfos_swa.size());
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(
|
||||
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
// TODO: if we fail again, we should attempt different splitting strategies
|
||||
@@ -220,13 +236,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||
|
||||
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
std::vector<uint32_t> heads_base,
|
||||
std::vector<uint32_t> heads_swa,
|
||||
slot_info_vec_t sinfos_base,
|
||||
slot_info_vec_t sinfos_swa,
|
||||
std::vector<llama_ubatch> ubatches) :
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
|
||||
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
|
||||
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
|
||||
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
|
||||
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ public:
|
||||
bool swa_full,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_seq_virt,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad);
|
||||
|
||||
@@ -68,12 +69,16 @@ public:
|
||||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
const uint32_t n_seq_virt = 1;
|
||||
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_base;
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
||||
};
|
||||
|
||||
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
|
||||
public:
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
|
||||
// used for errors
|
||||
llama_kv_cache_unified_iswa_context(llama_memory_status status);
|
||||
|
||||
@@ -90,8 +95,8 @@ public:
|
||||
// used to create a batch processing context from a batch
|
||||
llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
std::vector<uint32_t> heads_base,
|
||||
std::vector<uint32_t> heads_swa,
|
||||
slot_info_vec_t sinfos_base,
|
||||
slot_info_vec_t sinfos_swa,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_unified_iswa_context();
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,8 +24,6 @@ public:
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
using ubatch_heads = std::vector<uint32_t>;
|
||||
|
||||
struct defrag_info {
|
||||
bool empty() const {
|
||||
return ids.empty();
|
||||
@@ -37,6 +35,50 @@ public:
|
||||
std::vector<uint32_t> ids;
|
||||
};
|
||||
|
||||
// for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
|
||||
// KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
|
||||
struct slot_info {
|
||||
// data for ggml_set_rows
|
||||
using idx_vec_t = std::vector<uint32_t>;
|
||||
|
||||
llama_seq_id s0;
|
||||
llama_seq_id s1;
|
||||
|
||||
std::vector<llama_seq_id> seq_id_virt;
|
||||
std::vector<idx_vec_t> idxs;
|
||||
|
||||
uint32_t head() const {
|
||||
GGML_ASSERT(idxs.size() == 1);
|
||||
|
||||
return idxs.at(0).at(0);
|
||||
}
|
||||
|
||||
void resize(size_t n) {
|
||||
seq_id_virt.resize(n);
|
||||
idxs.resize(n);
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
GGML_ASSERT(idxs.size() == seq_id_virt.size());
|
||||
|
||||
return idxs.at(0).size();
|
||||
}
|
||||
|
||||
size_t n_seq_virt() const {
|
||||
return seq_id_virt.size();
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return idxs.empty();
|
||||
}
|
||||
|
||||
void clear() {
|
||||
idxs.clear();
|
||||
}
|
||||
};
|
||||
|
||||
using slot_info_vec_t = std::vector<slot_info>;
|
||||
|
||||
llama_kv_cache_unified(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
@@ -46,6 +88,7 @@ public:
|
||||
bool offload,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_seq_virt,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type);
|
||||
@@ -98,36 +141,44 @@ public:
|
||||
uint32_t get_n_kv() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
|
||||
|
||||
//
|
||||
// preparation API
|
||||
//
|
||||
|
||||
// find places for the provided ubatches in the cache, returns the head locations
|
||||
// find places for the provided ubatches in the cache, returns the slot infos
|
||||
// return empty vector on failure
|
||||
ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
||||
|
||||
// return the cell position where we can insert the ubatch
|
||||
// return -1 on failure to find a contiguous slot of kv cells
|
||||
int32_t find_slot(const llama_ubatch & ubatch) const;
|
||||
// find a slot of kv cells that can hold the ubatch
|
||||
// if cont == true, then the slot must be continuous
|
||||
// return empty slot_info on failure
|
||||
slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
|
||||
|
||||
// emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
|
||||
void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
|
||||
// emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
|
||||
void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
|
||||
|
||||
//
|
||||
// set_input API
|
||||
// input API
|
||||
//
|
||||
|
||||
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
|
||||
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||
|
||||
void set_input_k_shift(ggml_tensor * dst) const;
|
||||
|
||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_k_shift (ggml_tensor * dst) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
private:
|
||||
@@ -141,15 +192,15 @@ private:
|
||||
|
||||
ggml_tensor * k;
|
||||
ggml_tensor * v;
|
||||
|
||||
std::vector<ggml_tensor *> k_seq;
|
||||
std::vector<ggml_tensor *> v_seq;
|
||||
};
|
||||
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
uint32_t head = 0;
|
||||
|
||||
const uint32_t n_seq_max = 1;
|
||||
const uint32_t n_seq_max = 1;
|
||||
const uint32_t n_seq_virt = 1;
|
||||
|
||||
// required padding
|
||||
const uint32_t n_pad = 1;
|
||||
@@ -157,14 +208,26 @@ private:
|
||||
// SWA
|
||||
const uint32_t n_swa = 0;
|
||||
|
||||
// env: LLAMA_KV_CACHE_DEBUG
|
||||
int debug = 0;
|
||||
|
||||
// env: LLAMA_SET_ROWS (temporary)
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
|
||||
int supports_set_rows = false;
|
||||
|
||||
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
llama_kv_cells_unified cells;
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
std::vector<uint32_t> v_heads;
|
||||
|
||||
std::vector<llama_kv_cells_unified> v_cells;
|
||||
|
||||
// maps from a sequence id to a virtual sequence id
|
||||
std::vector<uint32_t> seq_virt_idx;
|
||||
|
||||
std::vector<kv_layer> layers;
|
||||
|
||||
@@ -211,8 +274,8 @@ private:
|
||||
class llama_kv_cache_unified_context : public llama_memory_context_i {
|
||||
public:
|
||||
// some shorthands
|
||||
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
|
||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||
|
||||
// used for errors
|
||||
llama_kv_cache_unified_context(llama_memory_status status);
|
||||
@@ -231,7 +294,7 @@ public:
|
||||
// used to create a batch procesing context from a batch
|
||||
llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv,
|
||||
ubatch_heads heads,
|
||||
slot_info_vec_t sinfos,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_unified_context();
|
||||
@@ -257,11 +320,16 @@ public:
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
|
||||
|
||||
void set_input_k_shift(ggml_tensor * dst) const;
|
||||
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
|
||||
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
void set_input_k_shift (ggml_tensor * dst) const;
|
||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
@@ -283,10 +351,10 @@ private:
|
||||
// batch processing context
|
||||
//
|
||||
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
// the index of the cur ubatch to process
|
||||
size_t i_cur = 0;
|
||||
|
||||
ubatch_heads heads;
|
||||
slot_info_vec_t sinfos;
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
@@ -297,7 +365,4 @@ private:
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// as the cache gets filled, the benefit from this heuristic disappears
|
||||
int32_t n_kv;
|
||||
|
||||
// the beginning of the current slot in which the ubatch will be inserted
|
||||
int32_t head;
|
||||
};
|
||||
|
||||
@@ -105,10 +105,30 @@ public:
|
||||
res.resize(n);
|
||||
|
||||
for (uint32_t j = 0; j < n; ++j) {
|
||||
res.pos[j] = pos[i + j];
|
||||
res.seq[j] = seq[i + j];
|
||||
const auto idx = i + j;
|
||||
|
||||
assert(shift[i + j] == 0);
|
||||
res.pos[j] = pos[idx];
|
||||
res.seq[j] = seq[idx];
|
||||
|
||||
assert(shift[idx] == 0);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
|
||||
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
|
||||
llama_kv_cells_unified res;
|
||||
|
||||
res.resize(idxs.size());
|
||||
|
||||
for (uint32_t j = 0; j < idxs.size(); ++j) {
|
||||
const auto idx = idxs[j];
|
||||
|
||||
res.pos[j] = pos[idx];
|
||||
res.seq[j] = seq[idx];
|
||||
|
||||
assert(shift[idx] == 0);
|
||||
}
|
||||
|
||||
return res;
|
||||
@@ -119,26 +139,58 @@ public:
|
||||
assert(i + other.pos.size() <= pos.size());
|
||||
|
||||
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
||||
if (pos[i + j] == -1 && other.pos[j] != -1) {
|
||||
const auto idx = i + j;
|
||||
|
||||
if (pos[idx] == -1 && other.pos[j] != -1) {
|
||||
used.insert(i + j);
|
||||
}
|
||||
|
||||
if (pos[i + j] != -1 && other.pos[j] == -1) {
|
||||
if (pos[idx] != -1 && other.pos[j] == -1) {
|
||||
used.erase(i + j);
|
||||
}
|
||||
|
||||
if (pos[i + j] != -1) {
|
||||
if (pos[idx] != -1) {
|
||||
seq_pos_rm(i + j);
|
||||
}
|
||||
|
||||
pos[i + j] = other.pos[j];
|
||||
seq[i + j] = other.seq[j];
|
||||
pos[idx] = other.pos[j];
|
||||
seq[idx] = other.seq[j];
|
||||
|
||||
if (pos[i + j] != -1) {
|
||||
if (pos[idx] != -1) {
|
||||
seq_pos_add(i + j);
|
||||
}
|
||||
|
||||
assert(shift[i + j] == 0);
|
||||
assert(shift[idx] == 0);
|
||||
}
|
||||
}
|
||||
|
||||
// set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
|
||||
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
|
||||
assert(idxs.size() == other.pos.size());
|
||||
|
||||
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
||||
const auto idx = idxs[j];
|
||||
|
||||
if (pos[idx] == -1 && other.pos[j] != -1) {
|
||||
used.insert(idx);
|
||||
}
|
||||
|
||||
if (pos[idx] != -1 && other.pos[j] == -1) {
|
||||
used.erase(idx);
|
||||
}
|
||||
|
||||
if (pos[idx] != -1) {
|
||||
seq_pos_rm(idx);
|
||||
}
|
||||
|
||||
pos[idx] = other.pos[j];
|
||||
seq[idx] = other.seq[j];
|
||||
|
||||
if (pos[idx] != -1) {
|
||||
seq_pos_add(idx);
|
||||
}
|
||||
|
||||
assert(shift[idx] == 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
|
||||
offload,
|
||||
kv_size,
|
||||
n_seq_max,
|
||||
1,
|
||||
n_pad,
|
||||
n_swa,
|
||||
swa_type
|
||||
@@ -70,7 +71,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
|
||||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
ubatch = balloc.split_equal(n_ubatch);
|
||||
ubatch = balloc.split_equal(n_ubatch, false);
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
@@ -80,6 +81,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
// prepare the recurrent batches first
|
||||
if (!mem_recr->prepare(ubatches)) {
|
||||
// TODO: will the recurrent cache be in an undefined context at this point?
|
||||
@@ -195,11 +201,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
|
||||
|
||||
llama_memory_hybrid_context::llama_memory_hybrid_context(
|
||||
llama_memory_hybrid * mem,
|
||||
std::vector<uint32_t> heads_attn,
|
||||
slot_info_vec_t sinfos_attn,
|
||||
std::vector<llama_ubatch> ubatches) :
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
|
||||
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
|
||||
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
@@ -92,6 +92,8 @@ private:
|
||||
|
||||
class llama_memory_hybrid_context : public llama_memory_context_i {
|
||||
public:
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
|
||||
// init failure
|
||||
explicit llama_memory_hybrid_context(llama_memory_status status);
|
||||
|
||||
@@ -107,7 +109,7 @@ public:
|
||||
// init success
|
||||
llama_memory_hybrid_context(
|
||||
llama_memory_hybrid * mem,
|
||||
std::vector<uint32_t> heads_attn,
|
||||
slot_info_vec_t sinfos_attn,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
~llama_memory_hybrid_context() = default;
|
||||
|
||||
@@ -374,10 +374,11 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
|
||||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
ubatch = balloc.split_equal(n_ubatch);
|
||||
ubatch = balloc.split_equal(n_ubatch, false);
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
@@ -14499,6 +14499,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
params.swa_full,
|
||||
cparams.n_ctx,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_seq_virt,
|
||||
cparams.n_ubatch,
|
||||
padding);
|
||||
} else {
|
||||
@@ -14513,6 +14514,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
cparams.offload_kqv,
|
||||
cparams.n_ctx,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_seq_virt,
|
||||
padding,
|
||||
hparams.n_swa,
|
||||
hparams.swa_type);
|
||||
|
||||
@@ -61,7 +61,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const int32_t n_kv_max = llama_n_ctx(ctx);
|
||||
|
||||
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
|
||||
llama_batch batch = llama_batch_init(n_kv_max*8, 0, 1); // TODO: tmp!!!
|
||||
|
||||
// decode in batches of ctx_params.n_batch tokens
|
||||
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
|
||||
@@ -119,9 +119,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
|
||||
|
||||
if (n_ctx_req > n_kv_max) {
|
||||
continue;
|
||||
}
|
||||
//if (n_ctx_req > n_kv_max) {
|
||||
// continue;
|
||||
//}
|
||||
|
||||
common_batch_clear(batch);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user