Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2026-02-26 14:23:22 +02:00)
Compare commits
8 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 0fc16b42e8 | |
| | 053b1539c0 | |
| | b3a89c3d9e | |
| | e15898d1c7 | |
| | 803f8baf4f | |
| | 3600cc2886 | |
| | c7e0a2054b | |
| | 3f55f781f1 | |
@@ -1348,9 +1348,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ));
     add_opt(common_arg(
         {"--prio"}, "N",
-        string_format("set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.cpuparams.priority),
+        string_format("set process/thread priority : low(-1), normal(0), medium(1), high(2), realtime(3) (default: %d)\n", params.cpuparams.priority),
         [](common_params & params, int prio) {
-            if (prio < 0 || prio > 3) {
+            if (prio < GGML_SCHED_PRIO_LOW || prio > GGML_SCHED_PRIO_REALTIME) {
                 throw std::invalid_argument("invalid value");
             }
             params.cpuparams.priority = (enum ggml_sched_priority) prio;
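A side benefit of checking against the enum endpoints instead of hard-coded literals is that the bounds keep tracking the enum if it ever grows. A self-contained sketch of the same validation (hypothetical helper, not part of the patch; the enum mirrors the ggml.h hunk further down):

```cpp
#include <stdexcept>

// mirrors the enum added to ggml.h in this commit (see the hunk below)
enum ggml_sched_priority {
    GGML_SCHED_PRIO_LOW = -1,
    GGML_SCHED_PRIO_NORMAL,
    GGML_SCHED_PRIO_MEDIUM,
    GGML_SCHED_PRIO_HIGH,
    GGML_SCHED_PRIO_REALTIME,
};

// hypothetical helper: validate an integer CLI value against the enum range
static ggml_sched_priority parse_prio(int prio) {
    if (prio < GGML_SCHED_PRIO_LOW || prio > GGML_SCHED_PRIO_REALTIME) {
        throw std::invalid_argument("invalid value");
    }
    return (enum ggml_sched_priority) prio;
}
```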
@@ -154,9 +154,10 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
             if (!rest.empty()) {
                 handle_reasoning(rest, /* closed */ !is_partial());
             }
-            if (!syntax_.thinking_forced_open) {
-                throw common_chat_msg_partial_exception(end_think);
-            }
+            // Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877)
+            // if (!syntax_.thinking_forced_open) {
+            //     throw common_chat_msg_partial_exception(end_think);
+            // }
             return true;
         }
     }
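The practical effect: a generation that is cut off while the model is still "thinking" no longer raises a partial-content exception. An illustrative case, assuming a `<think>`-style tag pair (hypothetical example; the actual tags come from the chat template):

```cpp
#include <string>

// before this change, parsing the partial stream below threw
// common_chat_msg_partial_exception because "</think>" never arrived;
// after it, the open tag is tolerated and the text is kept as reasoning
// content (closed = false), so streaming clients keep receiving deltas
const std::string partial = "<think>The user asks about X, so first I should";
```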
@@ -203,6 +203,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
     DWORD p = NORMAL_PRIORITY_CLASS;
     switch (prio) {
+        case GGML_SCHED_PRIO_LOW:      p = BELOW_NORMAL_PRIORITY_CLASS; break;
         case GGML_SCHED_PRIO_NORMAL:   p = NORMAL_PRIORITY_CLASS;       break;
         case GGML_SCHED_PRIO_MEDIUM:   p = ABOVE_NORMAL_PRIORITY_CLASS; break;
         case GGML_SCHED_PRIO_HIGH:     p = HIGH_PRIORITY_CLASS;         break;
@@ -228,6 +229,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
     int p = 0;
     switch (prio) {
+        case GGML_SCHED_PRIO_LOW:    p =   5; break;
         case GGML_SCHED_PRIO_NORMAL: p =   0; break;
         case GGML_SCHED_PRIO_MEDIUM: p =  -5; break;
         case GGML_SCHED_PRIO_HIGH:   p = -10; break;
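On POSIX, process priority is expressed as a nice value where larger means lower priority, hence LOW maps to +5 and HIGH to -10 above. A minimal sketch of how such a value is applied (assumption: the surrounding set_process_priority() body does roughly this; not the verbatim code):

```cpp
#include <sys/resource.h>
#include <cstdio>

// apply a nice value to the calling process (pid 0 = current process);
// raising priority (negative nice) typically requires elevated privileges
static bool apply_nice(int p) {
    if (setpriority(PRIO_PROCESS, 0, p) != 0) {
        perror("setpriority");
        return false;
    }
    return true;
}
```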
@@ -63,6 +63,7 @@ cmake --build build --config Release
     cmake --preset x64-windows-llvm-release
     cmake --build build-x64-windows-llvm-release
     ```
+- Curl usage is enabled by default and can be turned off with `-DLLAMA_CURL=OFF`. Otherwise you need to install development libraries for libcurl.

 ## BLAS Build
@@ -392,7 +392,7 @@ int main(int argc, char ** argv) {
             return 1;
         }

-        LOG_ERR("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
+        LOG_WRN("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);

         n_cache_miss += 1;
@@ -133,9 +133,8 @@ int main(int argc, char ** argv) {
             const int ib = i/n_batch - 1;
             const int bd = n_batch_grp*(n_grp - 1);

-            llama_kv_self_seq_add (ctx, 0, n_past - n_batch,         n_past,         ib*bd);
-            llama_kv_self_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
-            llama_kv_self_update  (ctx);
+            llama_kv_self_seq_add(ctx, 0, n_past - n_batch,         n_past,         ib*bd);
+            llama_kv_self_seq_div(ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);

             n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
         }
@@ -169,8 +168,6 @@ int main(int argc, char ** argv) {
             llama_kv_self_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
             llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
-            //llama_kv_self_defrag (ctx);
-            llama_kv_self_update (ctx);

             n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
@@ -200,8 +197,6 @@ int main(int argc, char ** argv) {
             llama_kv_self_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
             llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
-            //llama_kv_self_defrag (ctx);
-            llama_kv_self_update (ctx);

             n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
         }
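For intuition, the seq_add/seq_div pair above implements the grouped self-extend trick: shift the freshly decoded block into place, then integer-divide its positions so that n_grp consecutive tokens share one RoPE position. A worked example with hypothetical numbers:

```cpp
// worked example of the self-extend transform above (hypothetical numbers)
constexpr int n_batch     = 8;
constexpr int n_batch_grp = 4;
constexpr int n_grp       = 2;
constexpr int i           = 16;

constexpr int ib = i/n_batch - 1;           // 1
constexpr int bd = n_batch_grp*(n_grp - 1); // 4

static_assert(ib == 1 && bd == 4, "the new block is shifted by ib*bd = 4 positions");
// seq_div then divides the shifted range by n_grp = 2, so every pair of
// consecutive tokens collapses onto one position, halving context growth
```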
@@ -2181,6 +2181,7 @@ extern "C" {
     // scheduling priorities
     enum ggml_sched_priority {
+        GGML_SCHED_PRIO_LOW = -1,
         GGML_SCHED_PRIO_NORMAL,
         GGML_SCHED_PRIO_MEDIUM,
         GGML_SCHED_PRIO_HIGH,
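Because C enumerators continue counting from the previous value, prepending LOW = -1 leaves every existing enumerator at its old value, so callers that stored raw integers stay valid. A quick compile-time check (sketch):

```cpp
#include "ggml.h"

// prepending GGML_SCHED_PRIO_LOW = -1 does not renumber the existing values
static_assert(GGML_SCHED_PRIO_NORMAL   == 0, "ABI value unchanged");
static_assert(GGML_SCHED_PRIO_REALTIME == 3, "ABI value unchanged");
```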
@@ -2418,12 +2418,32 @@ static bool ggml_thread_apply_priority(int32_t prio) {
     // This is up to the applications.
     DWORD p = THREAD_PRIORITY_NORMAL;
     switch (prio) {
+        case GGML_SCHED_PRIO_LOW:      p = THREAD_PRIORITY_BELOW_NORMAL;  break;
         case GGML_SCHED_PRIO_NORMAL:   p = THREAD_PRIORITY_NORMAL;        break;
         case GGML_SCHED_PRIO_MEDIUM:   p = THREAD_PRIORITY_ABOVE_NORMAL;  break;
         case GGML_SCHED_PRIO_HIGH:     p = THREAD_PRIORITY_HIGHEST;       break;
         case GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break;
     }

+    if (prio != GGML_SCHED_PRIO_LOW) {
+        // Tell Windows that this thread should not be throttled (needs its own CPU core).
+        // Newer Windows 11 versions aggressively park (offline) CPU cores and often place
+        // all our threads onto the first 4 cores which results in terrible performance with
+        // n_threads > 4
+    #if _WIN32_WINNT >= 0x0602
+        THREAD_POWER_THROTTLING_STATE t;
+        ZeroMemory(&t, sizeof(t));
+        t.Version     = THREAD_POWER_THROTTLING_CURRENT_VERSION;
+        t.ControlMask = THREAD_POWER_THROTTLING_EXECUTION_SPEED;
+        t.StateMask   = 0;
+
+        if (!SetThreadInformation(GetCurrentThread(), ThreadPowerThrottling, &t, sizeof(t))) {
+            GGML_LOG_DEBUG("failed to disable thread power throttling %d : (%d)\n", prio, (int) GetLastError());
+            return false;
+        }
+    #endif
+    }
+
     if (prio == GGML_SCHED_PRIO_NORMAL) {
         // Keep inherited policy/priority
         return true;
@@ -2451,6 +2471,8 @@ static bool ggml_thread_apply_priority(int32_t prio) {
     struct sched_param p;
     int32_t policy = SCHED_OTHER;
     switch (prio) {
+        // TODO: there seems to be no way to set lower prio on Apple platforms
+        case GGML_SCHED_PRIO_LOW:      policy = SCHED_OTHER; p.sched_priority =  0;  break;
         case GGML_SCHED_PRIO_NORMAL:   policy = SCHED_OTHER; p.sched_priority =  0;  break;
         case GGML_SCHED_PRIO_MEDIUM:   policy = SCHED_FIFO;  p.sched_priority = 40;  break;
         case GGML_SCHED_PRIO_HIGH:     policy = SCHED_FIFO;  p.sched_priority = 80;  break;
@@ -2507,6 +2529,7 @@ static bool ggml_thread_apply_priority(int32_t prio) {
     struct sched_param p;
     int32_t policy = SCHED_OTHER;
     switch (prio) {
+        case GGML_SCHED_PRIO_LOW:      policy = SCHED_BATCH; p.sched_priority =  0;  break;
         case GGML_SCHED_PRIO_NORMAL:   policy = SCHED_OTHER; p.sched_priority =  0;  break;
         case GGML_SCHED_PRIO_MEDIUM:   policy = SCHED_FIFO;  p.sched_priority = 40;  break;
         case GGML_SCHED_PRIO_HIGH:     policy = SCHED_FIFO;  p.sched_priority = 80;  break;
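The power-throttling opt-out above is per-thread. Windows exposes the same knob at process granularity; a hypothetical process-wide counterpart (not part of this patch) would look like:

```cpp
#include <windows.h>

// hypothetical process-wide counterpart of the per-thread opt-out above;
// StateMask = 0 with the EXECUTION_SPEED bit set in ControlMask opts the
// process out of EcoQoS power throttling
static bool disable_process_power_throttling(void) {
    PROCESS_POWER_THROTTLING_STATE t;
    ZeroMemory(&t, sizeof(t));
    t.Version     = PROCESS_POWER_THROTTLING_CURRENT_VERSION;
    t.ControlMask = PROCESS_POWER_THROTTLING_EXECUTION_SPEED;
    t.StateMask   = 0;

    return SetProcessInformation(GetCurrentProcess(), ProcessPowerThrottling, &t, sizeof(t)) != FALSE;
}
```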
@@ -366,6 +366,8 @@ extern "C" {
         bool no_perf;    // measure performance timings
         bool op_offload; // offload host tensor operations to device
         bool swa_full;   // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
+                         // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
+                         //       ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
     };

     // model quantization parameters
@@ -502,6 +504,7 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_layer  (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head   (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head_kv(const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_swa    (const struct llama_model * model);

     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
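A plausible use of the new getter (hypothetical caller):

```cpp
#include "llama.h"

// hypothetical caller: detect whether the loaded model uses sliding-window
// attention; n_swa is the window size, 0 means no SWA layers
static bool model_uses_swa(const struct llama_model * model) {
    return llama_model_n_swa(model) > 0;
}
```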
@@ -652,7 +655,6 @@ extern "C" {
     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
     // If the KV cache is RoPEd, the KV data is updated accordingly:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_self_update()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_self_seq_add(
@@ -665,7 +667,6 @@ extern "C" {
     // Integer division of the positions by factor of `d > 1`
     // If the KV cache is RoPEd, the KV data is updated accordingly:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_self_update()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_self_seq_div(
@@ -693,16 +694,15 @@ extern "C" {
     // Defragment the KV cache
     // This will be applied:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_self_update()
-    // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG]
-    LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
+    LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
+        "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");

     // Check if the context supports KV cache shifting
     LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);

     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
-    // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG]
-    LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
+    LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
+        "simply remove this call, updates are applied lazily on the next llama_decode()");

     //
     // State / sessions
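Migration for client code is mechanical: drop the explicit maintenance calls. A before/after sketch based on the passkey example above (variable names as in that example):

```cpp
// before: explicit maintenance calls after editing the KV cache
llama_kv_self_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
llama_kv_self_defrag(ctx); // deprecated: defrag is now driven by 'defrag_thold'
llama_kv_self_update(ctx); // deprecated: applied lazily by llama_decode()

// after: simply remove both calls
llama_kv_self_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
```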
@@ -21,6 +21,9 @@ add_library(llama
             llama-impl.cpp
             llama-io.cpp
             llama-kv-cache.cpp
+            llama-kv-cache-unified.cpp
+            llama-kv-cache-unified-iswa.cpp
+            llama-kv-cache-recurrent.cpp
             llama-memory.cpp
             llama-mmap.cpp
             llama-model-loader.cpp
@@ -123,6 +123,11 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }

+    if (!params.swa_full && cparams.n_seq_max > 1) {
+        LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
+                __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
+    }
+
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
@@ -424,28 +429,33 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }

-void llama_context::kv_self_update() {
+bool llama_context::kv_self_update() {
     if (!memory) {
-        return;
+        return false;
     }

     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

-    if (kv_self->update(*this)) {
-        // if the KV cache did any computation, we have to reserve a new worst-case graph
-        const auto kv_state = kv_self->init_full();
-        if (!kv_state) {
-            throw std::runtime_error("failed to initialize KV cache");
-        }
-
-        const uint32_t n_seqs   = cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
-        if (!gf) {
-            LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
-        }
+    if (!kv_self->update(*this)) {
+        // no updates have been performed
+        return false;
     }
+
+    // if the KV cache did any computation, we have to reserve a new worst-case graph
+    const auto kv_state = kv_self->init_full();
+    if (!kv_state) {
+        throw std::runtime_error("failed to initialize KV cache");
+    }
+
+    const uint32_t n_seqs   = cparams.n_seq_max;
+    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
+    }
+
+    return true;
 }

 enum llama_pooling_type llama_context::pooling_type() const {
@@ -933,24 +943,44 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();

-    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!kv_state) {
-        return -2;
-    }
+    llama_memory_state_ptr kv_state;

-    switch (kv_state->get_status()) {
-        case LLAMA_MEMORY_STATUS_SUCCESS:
-            {
-            } break;
-        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
-            {
-                // not a fatal error, we can re-try with a different batch
-                return 1;
-            }
-        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
-            {
-                return -2;
-            }
-    }
+    bool did_defrag = false;
+
+    while (true) {
+        kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!kv_state) {
+            return -2;
+        }
+
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_defrag) {
+                        did_defrag = true;
+
+                        kv_self->defrag_sched(-1.0f);
+                        if (kv_self_update()) {
+                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+
+                            continue;
+                        }
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    return -2;
+                }
+        }
+
+        break;
+    }

     // reserve output buffer
@@ -2251,6 +2281,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
     return ctx->get_kv_self();
 }

+// deprecated
 void llama_kv_self_update(llama_context * ctx) {
     ctx->kv_self_update();
 }
@@ -2505,6 +2536,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     return kv->seq_pos_max(seq_id);
 }

+// deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
     auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2646,22 +2678,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
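With the defrag-and-retry moved inside llama_context::decode(), the C wrapper only logs hard failures and the soft failure is reported to the caller. A caller-side sketch (hypothetical handling):

```cpp
// return codes of llama_decode():
//   0  -> success
//   1  -> no KV cache slot found, even after the internal defrag retry
//  <0  -> fatal error for this batch
const int ret = llama_decode(ctx, batch);
if (ret == 1) {
    // soft failure: shrink the batch or free some sequences, then retry
} else if (ret != 0) {
    // fatal: give up on this request
}
```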
@@ -50,8 +50,9 @@ struct llama_context {
           llama_kv_cache * get_kv_self();
     const llama_kv_cache * get_kv_self() const;

-    void kv_self_update();
+    // return true if the KV cache was updated
+    // TODO: remove
+    bool kv_self_update();

     enum llama_pooling_type pooling_type() const;
@@ -3,7 +3,10 @@
 #include "llama-impl.h"
 #include "llama-batch.h"
 #include "llama-cparams.h"
-#include "llama-kv-cache.h"
+
+#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache-recurrent.h"

 #include <cassert>
 #include <cmath>
src/llama-kv-cache-recurrent.cpp (new file, 1132 lines; diff suppressed because it is too large)

src/llama-kv-cache-recurrent.h (new file, 191 lines)
@@ -0,0 +1,191 @@
#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache.h"

#include <set>
#include <vector>

//
// llama_kv_cache_recurrent
//

// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
//       see the implementation of llama_kv_cache_unified_state_i for an example how to do it
class llama_kv_cache_recurrent : public llama_kv_cache {
public:
    llama_kv_cache_recurrent(
            const llama_model & model,
                    ggml_type   type_k,
                    ggml_type   type_v,
                         bool   offload,
                     uint32_t   kv_size,
                     uint32_t   n_seq_max);

    ~llama_kv_cache_recurrent() = default;

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    llama_memory_state_ptr init_batch(
            const llama_batch & batch,
            uint32_t n_ubatch,
            bool embd_pooled,
            bool logits_all) override;

    llama_memory_state_ptr init_full() override;

    bool update(llama_context & lctx) override;

    void defrag_sched(float thold) override;

    bool prepare(const std::vector<llama_ubatch> & ubatches);

    // find a contiguous slot of kv cells and emplace the ubatch there
    bool find_slot(const llama_ubatch & ubatch);

    bool get_can_shift() const override;

    // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
    int32_t s_copy(int i) const;
    float   s_mask(int i) const;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
    uint32_t size = 0; // total number of cells, shared across all sequences
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
    uint32_t n = 0;

    // TODO: optimize for recurrent state needs
    struct kv_cell {
        llama_pos pos  = -1;
        int32_t   src  = -1; // used to copy states
        int32_t   tail = -1;

        std::set<llama_seq_id> seq_id;

        bool has_seq_id(const llama_seq_id & id) const {
            return seq_id.find(id) != seq_id.end();
        }

        bool is_empty() const {
            return seq_id.empty();
        }

        bool is_same_seq(const kv_cell & other) const {
            return seq_id == other.seq_id;
        }
    };

    std::vector<kv_cell> cells;

    std::vector<ggml_tensor *> k_l; // per layer
    std::vector<ggml_tensor *> v_l;

private:
    //const llama_model & model;
    const llama_hparams & hparams;

    const uint32_t n_seq_max = 1;

    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

class llama_kv_cache_recurrent_state : public llama_memory_state_i {
public:
    // used for errors
    llama_kv_cache_recurrent_state(llama_memory_status status);

    // used to create a full-cache state
    llama_kv_cache_recurrent_state(
            llama_memory_status status,
            llama_kv_cache_recurrent * kv);

    // used to create a state from a batch
    llama_kv_cache_recurrent_state(
            llama_memory_status status,
            llama_kv_cache_recurrent * kv,
            llama_sbatch sbatch,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_recurrent_state();

    //
    // llama_memory_state_i
    //

    bool next()  override;
    bool apply() override;

    std::vector<int64_t> & out_ids() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_recurrent_state specific API
    //

    uint32_t get_n_kv() const;
    uint32_t get_head() const;
    uint32_t get_size() const;

    ggml_tensor * get_k_l(int32_t il) const;
    ggml_tensor * get_v_l(int32_t il) const;

    int32_t s_copy(int i) const;
    float   s_mask(int i) const;

private:
    const llama_memory_status status;

    llama_kv_cache_recurrent * kv;

    llama_sbatch sbatch;

    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    //
    // data needed for building the compute graph for the current ubatch:
    // TODO: extract all the state like `head` and `n` here
    //

    const bool is_full = false;
};
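All llama_memory_state_i implementations introduced in these commits are consumed the same way; a fragment sketching the pattern, inferred from the interface above (assumes kv, batch, n_ubatch, embd_pooled and logits_all are in scope):

```cpp
// sketch: how a llama_memory_state_i is consumed, inferred from the interface
llama_memory_state_ptr state = kv->init_batch(batch, n_ubatch, embd_pooled, logits_all);
if (state && state->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
    do {
        state->apply();                               // commit cache metadata for this ubatch
        const llama_ubatch & ub = state->get_ubatch();
        // ... build and compute the graph for ub ...
    } while (state->next());                          // advance to the next ubatch
}
```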
src/llama-kv-cache-unified-iswa.cpp (new file, 249 lines)
@@ -0,0 +1,249 @@
#include "llama-kv-cache-unified-iswa.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-model.h"

#include <algorithm>
#include <cassert>

//
// llama_kv_cache_unified_iswa
//

llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
        const llama_model & model,
                ggml_type   type_k,
                ggml_type   type_v,
                     bool   v_trans,
                     bool   offload,
                     bool   swa_full,
                 uint32_t   kv_size,
                 uint32_t   n_seq_max,
                 uint32_t   n_ubatch,
                 uint32_t   n_pad) : hparams(model.hparams) {
    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };

    const uint32_t size_base = kv_size;

    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));

    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
    if (swa_full) {
        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");

        size_swa = size_base;
    }

    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);

    kv_base = std::make_unique<llama_kv_cache_unified>(
            model, std::move(filter_base), type_k, type_v,
            v_trans, offload, size_base, n_seq_max, n_pad,
            0, LLAMA_SWA_TYPE_NONE);

    LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

    kv_swa = std::make_unique<llama_kv_cache_unified>(
            model, std::move(filter_swa), type_k, type_v,
            v_trans, offload, size_swa, n_seq_max, n_pad,
            hparams.n_swa, hparams.swa_type);
}

void llama_kv_cache_unified_iswa::clear() {
    kv_base->clear();
    kv_swa ->clear();
}

bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    bool res = true;

    res = res & kv_base->seq_rm(seq_id, p0, p1);
    res = res & kv_swa ->seq_rm(seq_id, p0, p1);

    return res;
}

void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}

void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
    kv_base->seq_keep(seq_id);
    kv_swa ->seq_keep(seq_id);
}

void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    kv_base->seq_add(seq_id, p0, p1, shift);
    kv_swa ->seq_add(seq_id, p0, p1, shift);
}

void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    kv_base->seq_div(seq_id, p0, p1, d);
    kv_swa ->seq_div(seq_id, p0, p1, d);
}

llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
    return kv_swa->seq_pos_min(seq_id);
}

llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
    return kv_swa->seq_pos_max(seq_id);
}

llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
    GGML_UNUSED(embd_pooled);

    // TODO: if we fail with split_simple, we should attempt different splitting strategies
    //       but to do that properly, we first have to refactor the batches to be more flexible

    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);

    std::vector<llama_ubatch> ubatches;

    while (sbatch.n_tokens > 0) {
        auto ubatch = sbatch.split_simple(n_ubatch);

        ubatches.push_back(ubatch);
    }

    auto heads_base = kv_base->prepare(ubatches);
    if (heads_base.empty()) {
        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    }

    auto heads_swa = kv_swa->prepare(ubatches);
    if (heads_swa.empty()) {
        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    }

    assert(heads_base.size() == heads_swa.size());

    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
            this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
}

llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
}

bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
    bool res = false;

    res = res | kv_base->update(lctx);
    res = res | kv_swa ->update(lctx);

    return res;
}

void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
    kv_base->defrag_sched(thold);
    kv_swa ->defrag_sched(thold);
}

bool llama_kv_cache_unified_iswa::get_can_shift() const {
    return kv_base->get_size() == kv_swa->get_size();
}

void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    kv_base->state_write(io, seq_id);
    kv_swa ->state_write(io, seq_id);
}

void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    kv_base->state_read(io, seq_id);
    kv_swa ->state_read(io, seq_id);
}

llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
    return kv_base.get();
}

llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
    return kv_swa.get();
}

//
// llama_kv_cache_unified_iswa_state
//

llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}

llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
        llama_memory_status status,
        llama_kv_cache_unified_iswa * kv) : status(status) {
    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
}

llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
        llama_memory_status status,
        llama_kv_cache_unified_iswa * kv,
        llama_sbatch sbatch,
        std::vector<uint32_t> heads_base,
        std::vector<uint32_t> heads_swa,
        std::vector<llama_ubatch> ubatches)
    : status(status),
    sbatch(std::move(sbatch)),
    ubatches(std::move(ubatches)) {
    // note: here we copy the ubatches. not sure if this is ideal
    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa),  this->ubatches));
}

llama_kv_cache_unified_iswa_state::~llama_kv_cache_unified_iswa_state() = default;

bool llama_kv_cache_unified_iswa_state::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    state_base->next();
    state_swa ->next();

    if (++i_next >= ubatches.size()) {
        return false;
    }

    return true;
}

bool llama_kv_cache_unified_iswa_state::apply() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    bool res = true;

    res = res & state_base->apply();
    res = res & state_swa ->apply();

    return res;
}

std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return sbatch.out_ids;
}

llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
    return status;
}

const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
    return ubatches[i_next];
}

const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return state_base.get();
}

const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return state_swa.get();
}
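The constructor sizes the SWA cache from the window length rather than the full context, which is the main memory win of this PR. A worked example with hypothetical numbers (GGML_PAD rounds up to a multiple of its second argument):

```cpp
#include <cstdint>

// worked example of the SWA cache sizing above (hypothetical numbers)
constexpr uint32_t n_swa     = 4096;  // sliding-window size
constexpr uint32_t n_seq_max = 2;
constexpr uint32_t n_ubatch  = 512;
constexpr uint32_t n_pad     = 32;
constexpr uint32_t size_base = 32768; // full-context cache size

// GGML_PAD(x, n) == ((x + n - 1) & ~(n - 1))
constexpr uint32_t padded   = ((n_swa*n_seq_max + n_ubatch) + n_pad - 1) & ~(n_pad - 1);
constexpr uint32_t size_swa = padded < size_base ? padded : size_base;

static_assert(size_swa == 8704, "the SWA cache is much smaller than the 32768-cell base cache");
```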
src/llama-kv-cache-unified-iswa.h (new file, 136 lines)
@@ -0,0 +1,136 @@
#pragma once

#include "llama-kv-cache-unified.h"

#include <vector>

//
// llama_kv_cache_unified_iswa
//

// utilizes two instances of llama_kv_cache_unified
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers

class llama_kv_cache_unified_iswa : public llama_kv_cache {
public:
    llama_kv_cache_unified_iswa(
            const llama_model & model,
                    ggml_type   type_k,
                    ggml_type   type_v,
                         bool   v_trans,
                         bool   offload,
                         bool   swa_full,
                     uint32_t   kv_size,
                     uint32_t   n_seq_max,
                     uint32_t   n_ubatch,
                     uint32_t   n_pad);

    ~llama_kv_cache_unified_iswa() = default;

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    llama_memory_state_ptr init_batch(
            const llama_batch & batch,
            uint32_t n_ubatch,
            bool embd_pooled,
            bool logits_all) override;

    llama_memory_state_ptr init_full() override;

    bool update(llama_context & lctx) override;

    void defrag_sched(float thold) override;

    bool get_can_shift() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

    //
    // llama_kv_cache_unified_iswa specific API
    //

    llama_kv_cache_unified * get_base() const;
    llama_kv_cache_unified * get_swa () const;

private:
    const llama_hparams & hparams;

    std::unique_ptr<llama_kv_cache_unified> kv_base;
    std::unique_ptr<llama_kv_cache_unified> kv_swa;
};

class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
public:
    // used for errors
    llama_kv_cache_unified_iswa_state(llama_memory_status status);

    // used to create a full-cache state
    llama_kv_cache_unified_iswa_state(
            llama_memory_status status,
            llama_kv_cache_unified_iswa * kv);

    // used to create a state from a batch
    llama_kv_cache_unified_iswa_state(
            llama_memory_status status,
            llama_kv_cache_unified_iswa * kv,
            llama_sbatch sbatch,
            std::vector<uint32_t> heads_base,
            std::vector<uint32_t> heads_swa,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_unified_iswa_state();

    //
    // llama_memory_state_i
    //

    bool next()  override;
    bool apply() override;

    std::vector<int64_t> & out_ids() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_unified_iswa_state specific API
    //

    const llama_kv_cache_unified_state * get_base() const;
    const llama_kv_cache_unified_state * get_swa () const;

private:
    const llama_memory_status status;

    //llama_kv_cache_unified_iswa * kv;

    llama_sbatch sbatch;

    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    std::unique_ptr<llama_kv_cache_unified_state> state_base;
    std::unique_ptr<llama_kv_cache_unified_state> state_swa;
};
src/llama-kv-cache-unified.cpp (new file, 1717 lines; diff suppressed because it is too large)

src/llama-kv-cache-unified.h (new file, 278 lines)
@@ -0,0 +1,278 @@
#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache.h"
#include "llama-kv-cells.h"

#include <unordered_map>
#include <vector>

struct llama_cparams;
struct llama_hparams;
struct llama_model;
struct llama_context;

//
// llama_kv_cache_unified
//

class llama_kv_cache_unified : public llama_kv_cache {
public:
    static uint32_t get_padding(const llama_cparams & cparams);

    // this callback is used to filter out layers that should not be included in the cache
    using layer_filter_cb = std::function<bool(int32_t il)>;

    llama_kv_cache_unified(
            const llama_model & model,
              layer_filter_cb && filter,
                    ggml_type   type_k,
                    ggml_type   type_v,
                         bool   v_trans,
                         bool   offload,
                     uint32_t   kv_size,
                     uint32_t   n_seq_max,
                     uint32_t   n_pad,
                     uint32_t   n_swa,
               llama_swa_type   swa_type);

    ~llama_kv_cache_unified() = default;

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    llama_memory_state_ptr init_batch(
            const llama_batch & batch,
            uint32_t n_ubatch,
            bool embd_pooled,
            bool logits_all) override;

    llama_memory_state_ptr init_full() override;

    bool update(llama_context & lctx) override;

    void defrag_sched(float thold) override;

    bool get_can_shift() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

    //
    // llama_kv_cache_unified specific API
    //

    uint32_t get_size() const;

    //
    // graph_build API
    //

    uint32_t get_n_kv() const;

    // get views of the current state of the cache
    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;

    // store k_cur and v_cur in the cache based on the provided head location
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;

    //
    // preparation API
    //

    // find places for the provided ubatches in the cache, returns the head locations
    // return empty vector on failure
    std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);

    // return the cell position where we can insert the ubatch
    // return -1 on failure to find a contiguous slot of kv cells
    int32_t find_slot(const llama_ubatch & ubatch) const;

    // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
    void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);

    //
    // set_input API
    //

    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_k_shift   (ggml_tensor * dst) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

private:
    const llama_model   & model;
    const llama_hparams & hparams;

    struct kv_layer {
        // layer index in the model
        // note: can be different from the layer index in the KV cache
        uint32_t il;

        ggml_tensor * k;
        ggml_tensor * v;
    };

    bool do_defrag = false;
    bool v_trans   = true;  // the value tensor is transposed

    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
    uint32_t head = 0;

    const uint32_t n_seq_max = 1;

    // required padding
    const uint32_t n_pad = 1;

    // SWA
    const uint32_t n_swa = 0;

    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    llama_kv_cells_unified cells;

    std::vector<kv_layer> layers;

    // model layer id -> KV cache layer id
    std::unordered_map<int32_t, int32_t> map_layer_ids;

    // defrag
    struct {
        std::vector<uint32_t> ids;
    } defrag_info;

    // return true if cells have been moved
    bool defrag_prepare(int32_t n_max_nodes);

    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    bool is_masked_swa(llama_pos p0, llama_pos p1) const;

    ggml_tensor * build_rope_shift(
            const llama_cparams & cparams,
                   ggml_context * ctx,
                    ggml_tensor * cur,
                    ggml_tensor * shift,
                    ggml_tensor * factors,
                          float   freq_base,
                          float   freq_scale) const;

    llm_graph_result_ptr build_graph_shift(
            const llama_cparams & cparams,
                   ggml_context * ctx,
                    ggml_cgraph * gf) const;

    llm_graph_result_ptr build_graph_defrag(
            const llama_cparams & cparams,
                   ggml_context * ctx,
                    ggml_cgraph * gf) const;

    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

class llama_kv_cache_unified_state : public llama_memory_state_i {
public:
    // used for errors
    llama_kv_cache_unified_state(llama_memory_status status);

    // used to create a full-cache state
    llama_kv_cache_unified_state(
            llama_memory_status status,
            llama_kv_cache_unified * kv);

    // used to create a state from a batch
    llama_kv_cache_unified_state(
            llama_memory_status status,
            llama_kv_cache_unified * kv,
            llama_sbatch sbatch,
            std::vector<uint32_t> heads,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_unified_state();

    //
    // llama_memory_state_i
    //

    bool next()  override;
    bool apply() override;

    std::vector<int64_t> & out_ids() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_unified_state specific API
    //

    uint32_t get_n_kv() const;

    // get views of the current state of the cache
    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

    // store k_cur and v_cur in the cache based on the provided head location
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;

    void set_input_k_shift(ggml_tensor * dst) const;

    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

private:
    const llama_memory_status status;

    llama_kv_cache_unified * kv;

    llama_sbatch sbatch;

    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<uint32_t>     heads;
    std::vector<llama_ubatch> ubatches;

    //
    // data needed for building the compute graph for the current ubatch:
    //

    // a heuristic, to avoid attending the full cache if it is not yet utilized
    // as the cache gets filled, the benefit from this heuristic disappears
    int32_t n_kv;

    // the beginning of the current slot in which the ubatch will be inserted
    int32_t head;
};
(file diff suppressed because it is too large)
@@ -2,21 +2,7 @@

 #include "llama.h"
 #include "llama-io.h"
-#include "llama-batch.h"
-#include "llama-graph.h"
 #include "llama-memory.h"
-#include "llama-kv-cells.h"
-
-#include "ggml-cpp.h"
-
-#include <set>
-#include <unordered_map>
-#include <vector>
-
-struct llama_cparams;
-struct llama_hparams;
-struct llama_model;
-struct llama_context;

 struct llama_kv_cache : public llama_memory_i {
     virtual ~llama_kv_cache() = default;
@@ -56,581 +42,3 @@ struct llama_kv_cache : public llama_memory_i {
|
||||
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
|
||||
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
|
||||
};
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified
|
||||
//
|
||||
|
||||
class llama_kv_cache_unified : public llama_kv_cache {
|
||||
public:
|
||||
static uint32_t get_padding(const llama_cparams & cparams);
|
||||
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
llama_kv_cache_unified(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type);
|
||||
|
||||
~llama_kv_cache_unified() = default;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
void clear() override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache
|
||||
//
|
||||
|
||||
llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) override;
|
||||
|
||||
llama_memory_state_ptr init_full() override;
|
||||
|
||||
bool update(llama_context & lctx) override;
|
||||
|
||||
void defrag_sched(float thold) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified specific API
|
||||
//
|
||||
|
||||
uint32_t get_size() const;
|
||||
|
||||
//
|
||||
// graph_build API
|
||||
//
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
|
||||
|
||||
//
|
||||
// preparation API
|
||||
//
|
||||
|
||||
// find places for the provided ubatches in the cache, returns the head locations
|
||||
// return empty vector on failure
|
||||
std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
// return the cell position where we can insert the ubatch
|
||||
// return -1 on failure to find a contiguous slot of kv cells
|
||||
int32_t find_slot(const llama_ubatch & ubatch) const;
|
||||
|
||||
// emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
|
||||
void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
|
||||
|
||||
//
|
||||
// set_input API
|
||||
//
|
||||
|
||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_k_shift (ggml_tensor * dst) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
private:
|
||||
const llama_model & model;
|
||||
const llama_hparams & hparams;
|
||||
|
||||
struct kv_layer {
|
||||
// layer index in the model
|
||||
// note: can be different from the layer index in the KV cache
|
||||
uint32_t il;
|
||||
|
||||
ggml_tensor * k;
|
||||
ggml_tensor * v;
|
||||
};
|
||||
|
||||
bool do_defrag = false;
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
uint32_t head = 0;
|
||||
|
||||
const uint32_t n_seq_max = 1;
|
||||
|
||||
// required padding
|
||||
const uint32_t n_pad = 1;
|
||||
|
||||
// SWA
|
||||
const uint32_t n_swa = 0;
|
||||
|
||||
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
llama_kv_cells_unified cells;
|
||||
|
||||
std::vector<kv_layer> layers;
|
||||
|
||||
// model layer id -> KV cache layer id
|
||||
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
||||
|
||||
// defrag
|
||||
struct {
|
||||
std::vector<uint32_t> ids;
|
||||
} defrag_info;
|
||||
|
||||
// return true if cells have been moved
|
||||
bool defrag_prepare(int32_t n_max_nodes);
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
|
||||
bool is_masked_swa(llama_pos p0, llama_pos p1) const;
|
||||
|
||||
ggml_tensor * build_rope_shift(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale) const;
|
||||
|
||||
llm_graph_result_ptr build_graph_shift(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf) const;
|
||||
|
||||
llm_graph_result_ptr build_graph_defrag(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf) const;
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
class llama_kv_cache_unified_state : public llama_memory_state_i {
|
||||
public:
|
||||
// used for errors
|
||||
llama_kv_cache_unified_state(llama_memory_status status);
|
||||
|
||||
// used to create a full-cache state
|
||||
llama_kv_cache_unified_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified * kv);
|
||||
|
||||
// used to create a state from a batch
|
||||
llama_kv_cache_unified_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_sbatch sbatch,
|
||||
std::vector<uint32_t> heads,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_unified_state();
|
||||
|
||||
//
|
||||
// llama_memory_state_i
|
||||
//
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
std::vector<int64_t> & out_ids() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_state specific API
|
||||
//
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
|
||||
|
||||
void set_input_k_shift(ggml_tensor * dst) const;
|
||||
|
||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
private:
|
||||
const llama_memory_status status;
|
||||
|
||||
llama_kv_cache_unified * kv;
|
||||
|
||||
llama_sbatch sbatch;
|
||||
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
|
||||
std::vector<uint32_t> heads;
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
//
|
||||
// data needed for building the compute graph for the current ubatch:
|
||||
//
|
||||
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// as the cache gets filled, the benefit from this heuristic disappears
|
||||
int32_t n_kv;
|
||||
|
||||
// the beginning of the current slot in which the ubatch will be inserted
|
||||
int32_t head;
|
||||
};
|
||||
|
||||
//
// llama_kv_cache_unified_iswa
//

// utilizes two instances of llama_kv_cache_unified
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers

class llama_kv_cache_unified_iswa : public llama_kv_cache {
public:
    llama_kv_cache_unified_iswa(
            const llama_model & model,
            ggml_type type_k,
            ggml_type type_v,
            bool v_trans,
            bool offload,
            bool swa_full,
            uint32_t kv_size,
            uint32_t n_seq_max,
            uint32_t n_batch,
            uint32_t n_pad);

    ~llama_kv_cache_unified_iswa() = default;

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    llama_memory_state_ptr init_batch(
            const llama_batch & batch,
            uint32_t n_ubatch,
            bool embd_pooled,
            bool logits_all) override;

    llama_memory_state_ptr init_full() override;

    bool update(llama_context & lctx) override;

    void defrag_sched(float thold) override;

    bool get_can_shift() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;

    //
    // llama_kv_cache_unified_iswa specific API
    //

    llama_kv_cache_unified * get_base() const;
    llama_kv_cache_unified * get_swa () const;

private:
    const llama_hparams & hparams;

    std::unique_ptr<llama_kv_cache_unified> kv_base;
    std::unique_ptr<llama_kv_cache_unified> kv_swa;
};
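
Given the non-SWA/SWA split described above, per-layer consumers only need to pick the correct inner cache. A hedged sketch of the dispatch; hparams.is_swa(il) is assumed here to report whether layer il uses a sliding window, it is not declared in this header:

// hypothetical sketch: route layer `il` to the SWA or non-SWA inner cache;
// hparams.is_swa(il) is an assumed query, not part of the code shown here
static llama_kv_cache_unified * select_cache_sketch(
        const llama_kv_cache_unified_iswa & kv,
        const llama_hparams & hparams,
        int32_t il) {
    return hparams.is_swa(il) ? kv.get_swa() : kv.get_base();
}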
class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
public:
    // used for errors
    llama_kv_cache_unified_iswa_state(llama_memory_status status);

    // used to create a full-cache state
    llama_kv_cache_unified_iswa_state(
            llama_memory_status status,
            llama_kv_cache_unified_iswa * kv);

    // used to create a state from a batch
    llama_kv_cache_unified_iswa_state(
            llama_memory_status status,
            llama_kv_cache_unified_iswa * kv,
            llama_sbatch sbatch,
            std::vector<uint32_t> heads_base,
            std::vector<uint32_t> heads_swa,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_unified_iswa_state();

    //
    // llama_memory_state_i
    //

    bool next()  override;
    bool apply() override;

    std::vector<int64_t> & out_ids() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_unified_iswa_state specific API
    //

    const llama_kv_cache_unified_state * get_base() const;
    const llama_kv_cache_unified_state * get_swa() const;

private:
    const llama_memory_status status;

    //llama_kv_cache_unified_iswa * kv;

    llama_sbatch sbatch;

    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    std::unique_ptr<llama_kv_cache_unified_state> state_base;
    std::unique_ptr<llama_kv_cache_unified_state> state_swa;
};

//
// llama_kv_cache_recurrent
//

// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
class llama_kv_cache_recurrent : public llama_kv_cache {
public:
    llama_kv_cache_recurrent(
            const llama_model & model,
            ggml_type type_k,
            ggml_type type_v,
            bool offload,
            uint32_t kv_size,
            uint32_t n_seq_max);

    ~llama_kv_cache_recurrent() = default;

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    llama_memory_state_ptr init_batch(
            const llama_batch & batch,
            uint32_t n_ubatch,
            bool embd_pooled,
            bool logits_all) override;

    llama_memory_state_ptr init_full() override;

    bool update(llama_context & lctx) override;

    void defrag_sched(float thold) override;

    bool prepare(const std::vector<llama_ubatch> & ubatches);

    // find a contiguous slot of kv cells and emplace the ubatch there
    bool find_slot(const llama_ubatch & ubatch);

    bool get_can_shift() const override;

    // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
    int32_t s_copy(int i) const;
    float   s_mask(int i) const;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;

    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
    uint32_t size = 0; // total number of cells, shared across all sequences
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
    uint32_t n = 0;

    // TODO: optimize for recurrent state needs
    struct kv_cell {
        llama_pos pos  = -1;
        int32_t   src  = -1; // used to copy states
        int32_t   tail = -1;

        std::set<llama_seq_id> seq_id;

        bool has_seq_id(const llama_seq_id & id) const {
            return seq_id.find(id) != seq_id.end();
        }

        bool is_empty() const {
            return seq_id.empty();
        }

        bool is_same_seq(const kv_cell & other) const {
            return seq_id == other.seq_id;
        }
    };

    std::vector<kv_cell> cells;

    std::vector<ggml_tensor *> k_l; // per layer
    std::vector<ggml_tensor *> v_l;

private:
    //const llama_model & model;
    const llama_hparams & hparams;

    const uint32_t n_seq_max = 1;

    std::vector<ggml_context_ptr> ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
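
The kv_cell helpers above are the building blocks for slot search. A hedged sketch of the scan implied by the find_slot() comment; the real implementation additionally maintains the per-sequence tail and the src indices used for state copies:

// hypothetical sketch: locate a run of n_tokens empty cells, as implied by
// the find_slot() comment; returns the start index, or -1 if no run exists
static int32_t find_empty_run_sketch(
        const std::vector<llama_kv_cache_recurrent::kv_cell> & cells,
        uint32_t n_tokens) {
    uint32_t run = 0;
    for (uint32_t i = 0; i < cells.size(); ++i) {
        run = cells[i].is_empty() ? run + 1 : 0;
        if (run == n_tokens) {
            return (int32_t) (i + 1 - n_tokens);
        }
    }
    return -1;
}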
class llama_kv_cache_recurrent_state : public llama_memory_state_i {
public:
    // used for errors
    llama_kv_cache_recurrent_state(llama_memory_status status);

    // used to create a full-cache state
    llama_kv_cache_recurrent_state(
            llama_memory_status status,
            llama_kv_cache_recurrent * kv);

    // used to create a state from a batch
    llama_kv_cache_recurrent_state(
            llama_memory_status status,
            llama_kv_cache_recurrent * kv,
            llama_sbatch sbatch,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_recurrent_state();

    //
    // llama_memory_state_i
    //

    bool next()  override;
    bool apply() override;

    std::vector<int64_t> & out_ids() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_recurrent_state specific API
    //

    uint32_t get_n_kv() const;
    uint32_t get_head() const;
    uint32_t get_size() const;

    ggml_tensor * get_k_l(int32_t il) const;
    ggml_tensor * get_v_l(int32_t il) const;

    int32_t s_copy(int i) const;
    float   s_mask(int i) const;

private:
    const llama_memory_status status;

    llama_kv_cache_recurrent * kv;

    llama_sbatch sbatch;

    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    //
    // data needed for building the compute graph for the current ubatch:
    // TODO: extract all the state like `head` and `n` here
    //

    const bool is_full = false;
};
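
All of these state classes share the llama_memory_state_i iteration protocol. A hedged sketch of a driving loop; the function is illustrative (the real driver lives elsewhere in the library) and LLAMA_MEMORY_STATUS_SUCCESS is assumed to be the success value of llama_memory_status:

// hypothetical driver over a memory state obtained from init_batch()
static bool process_state_sketch(llama_memory_state_i & state) {
    if (state.get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return false; // the batch could not be scheduled into the cache
    }
    do {
        if (!state.apply()) { // commit cache metadata for the current ubatch
            return false;
        }
        const llama_ubatch & ubatch = state.get_ubatch();
        GGML_UNUSED(ubatch); // build and compute the graph for `ubatch` here
    } while (state.next()); // advance to the next ubatch, if any
    return true;
}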
@@ -5,7 +5,10 @@
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-model-loader.h"
#include "llama-kv-cache.h"

#include "llama-kv-cache-unified.h"
#include "llama-kv-cache-unified-iswa.h"
#include "llama-kv-cache-recurrent.h"

#include "ggml-cpp.h"
@@ -13230,7 +13233,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                params.swa_full,
                cparams.n_ctx,
                cparams.n_seq_max,
                cparams.n_batch,
                cparams.n_ubatch,
                padding);
        } else {
            GGML_ASSERT(!hparams.is_swa_any());
@@ -13593,6 +13596,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
    return model->hparams.n_head_kv();
}

int32_t llama_model_n_swa(const llama_model * model) {
    return model->hparams.n_swa;
}

// deprecated
int32_t llama_n_ctx_train(const llama_model * model) {
    return llama_model_n_ctx_train(model);
@@ -1041,6 +1041,15 @@ static void test_template_output_parsers() {
            "<tool_call>\n"
            "{\"name\": \"python\", \"arguments\": {\"code\":\"# This is a program:\\nprint('hey')\"}}\n"
            "</tool_call>");
        assert_msg_equals(
            simple_assist_msg("", /* reasoning_content= */ "<tool_call>nah uhg</tool_call>"),
            common_chat_parse(
                "<think><tool_call>nah uhg</tool_call>",
                /* is_partial= */ false,
                {
                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
                }));
    }
    {
        auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
@@ -315,7 +315,7 @@ static void print_usage(int /* argc */, char ** argv) {
    printf("  --numa <distribute|isolate|numactl>       numa mode (default: disabled)\n");
    printf("  -r, --repetitions <n>                     number of times to repeat each test (default: %d)\n",
           cmd_params_defaults.reps);
    printf("  --prio <0|1|2|3>                          process/thread priority (default: %d)\n",
    printf("  --prio <-1|0|1|2|3>                       process/thread priority (default: %d)\n",
           cmd_params_defaults.prio);
    printf("  --delay <0...N>                           (seconds) delay between each test (default: %d)\n",
           cmd_params_defaults.delay);
Binary file not shown.
@@ -2016,11 +2016,6 @@ struct server_context {
            params_base.n_cache_reuse = 0;
            SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
        }

        if (!params_base.speculative.model.path.empty()) {
            SRV_ERR("%s\n", "err: speculative decode is not supported by this context");
            return false;
        }
    }

    return true;
@@ -3215,8 +3210,14 @@ struct server_context {

                if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                    const auto pos_min = llama_kv_self_seq_pos_min(ctx, slot.id);
                    if (pos_min > 0) {
                        SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
                    if (pos_min == -1) {
                        SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
                        GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                    }

                    const auto n_swa = llama_model_n_swa(model);
                    if (pos_min > slot.n_past - n_swa) {
                        SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
                        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
                                "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
                        llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
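
The guard in this hunk reduces to a single predicate: a slot's cached prefix is reusable only while the sliding window still covers the oldest position the reused tokens would attend to. A hedged restatement (the helper name is illustrative):

// hedged restatement of the SWA reuse check above: pos_min is the smallest
// position still present in the sequence's KV cells; anything older than
// n_past - n_swa has been pruned by the sliding window
static bool can_reuse_swa_cache_sketch(int n_past, int pos_min, int n_swa) {
    return pos_min >= 0 && pos_min <= n_past - n_swa;
}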
@@ -3431,7 +3432,7 @@ struct server_context {
                    // retry with half the batch size to try to find a free slot in the KV cache
                    n_batch /= 2;

                    SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
                    SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);

                    continue; // continue loop of n_batch
                }
@@ -5,21 +5,24 @@ import { AppContextProvider, useAppContext } from './utils/app.context';
import ChatScreen from './components/ChatScreen';
import SettingDialog from './components/SettingDialog';
import { Toaster } from 'react-hot-toast';
import { ModalProvider } from './components/ModalProvider';

function App() {
  return (
    <HashRouter>
      <div className="flex flex-row drawer lg:drawer-open">
        <AppContextProvider>
          <Routes>
            <Route element={<AppLayout />}>
              <Route path="/chat/:convId" element={<ChatScreen />} />
              <Route path="*" element={<ChatScreen />} />
            </Route>
          </Routes>
        </AppContextProvider>
      </div>
    </HashRouter>
    <ModalProvider>
      <HashRouter>
        <div className="flex flex-row drawer lg:drawer-open">
          <AppContextProvider>
            <Routes>
              <Route element={<AppLayout />}>
                <Route path="/chat/:convId" element={<ChatScreen />} />
                <Route path="*" element={<ChatScreen />} />
              </Route>
            </Routes>
          </AppContextProvider>
        </div>
      </HashRouter>
    </ModalProvider>
  );
}
tools/server/webui/src/components/ModalProvider.tsx
Normal file
@@ -0,0 +1,151 @@
import React, { createContext, useState, useContext } from 'react';

type ModalContextType = {
  showConfirm: (message: string) => Promise<boolean>;
  showPrompt: (
    message: string,
    defaultValue?: string
  ) => Promise<string | undefined>;
  showAlert: (message: string) => Promise<void>;
};
const ModalContext = createContext<ModalContextType>(null!);

interface ModalState<T> {
  isOpen: boolean;
  message: string;
  defaultValue?: string;
  resolve: ((value: T) => void) | null;
}

export function ModalProvider({ children }: { children: React.ReactNode }) {
  const [confirmState, setConfirmState] = useState<ModalState<boolean>>({
    isOpen: false,
    message: '',
    resolve: null,
  });
  const [promptState, setPromptState] = useState<
    ModalState<string | undefined>
  >({ isOpen: false, message: '', resolve: null });
  const [alertState, setAlertState] = useState<ModalState<void>>({
    isOpen: false,
    message: '',
    resolve: null,
  });
  const inputRef = React.useRef<HTMLInputElement>(null);

  const showConfirm = (message: string): Promise<boolean> => {
    return new Promise((resolve) => {
      setConfirmState({ isOpen: true, message, resolve });
    });
  };

  const showPrompt = (
    message: string,
    defaultValue?: string
  ): Promise<string | undefined> => {
    return new Promise((resolve) => {
      setPromptState({ isOpen: true, message, defaultValue, resolve });
    });
  };

  const showAlert = (message: string): Promise<void> => {
    return new Promise((resolve) => {
      setAlertState({ isOpen: true, message, resolve });
    });
  };

  const handleConfirm = (result: boolean) => {
    confirmState.resolve?.(result);
    setConfirmState({ isOpen: false, message: '', resolve: null });
  };

  const handlePrompt = (result?: string) => {
    promptState.resolve?.(result);
    setPromptState({ isOpen: false, message: '', resolve: null });
  };

  const handleAlertClose = () => {
    alertState.resolve?.();
    setAlertState({ isOpen: false, message: '', resolve: null });
  };

  return (
    <ModalContext.Provider value={{ showConfirm, showPrompt, showAlert }}>
      {children}

      {/* Confirm Modal */}
      {confirmState.isOpen && (
        <dialog className="modal modal-open z-[1100]">
          <div className="modal-box">
            <h3 className="font-bold text-lg">{confirmState.message}</h3>
            <div className="modal-action">
              <button
                className="btn btn-ghost"
                onClick={() => handleConfirm(false)}
              >
                Cancel
              </button>
              <button
                className="btn btn-error"
                onClick={() => handleConfirm(true)}
              >
                Confirm
              </button>
            </div>
          </div>
        </dialog>
      )}

      {/* Prompt Modal */}
      {promptState.isOpen && (
        <dialog className="modal modal-open z-[1100]">
          <div className="modal-box">
            <h3 className="font-bold text-lg">{promptState.message}</h3>
            <input
              type="text"
              className="input input-bordered w-full mt-2"
              defaultValue={promptState.defaultValue}
              ref={inputRef}
              onKeyDown={(e) => {
                if (e.key === 'Enter') {
                  handlePrompt((e.target as HTMLInputElement).value);
                }
              }}
            />
            <div className="modal-action">
              <button className="btn btn-ghost" onClick={() => handlePrompt()}>
                Cancel
              </button>
              <button
                className="btn btn-primary"
                onClick={() => handlePrompt(inputRef.current?.value)}
              >
                Submit
              </button>
            </div>
          </div>
        </dialog>
      )}

      {/* Alert Modal */}
      {alertState.isOpen && (
        <dialog className="modal modal-open z-[1100]">
          <div className="modal-box">
            <h3 className="font-bold text-lg">{alertState.message}</h3>
            <div className="modal-action">
              <button className="btn" onClick={handleAlertClose}>
                OK
              </button>
            </div>
          </div>
        </dialog>
      )}
    </ModalContext.Provider>
  );
}

export function useModals() {
  const context = useContext(ModalContext);
  if (!context) throw new Error('useModals must be used within ModalProvider');
  return context;
}
@@ -13,6 +13,7 @@ import {
  SquaresPlusIcon,
} from '@heroicons/react/24/outline';
import { OpenInNewTab } from '../utils/common';
import { useModals } from './ModalProvider';

type SettKey = keyof typeof CONFIG_DEFAULT;
@@ -282,14 +283,15 @@ export default function SettingDialog({
  const [localConfig, setLocalConfig] = useState<typeof CONFIG_DEFAULT>(
    JSON.parse(JSON.stringify(config))
  );
  const { showConfirm, showAlert } = useModals();

  const resetConfig = () => {
    if (window.confirm('Are you sure you want to reset all settings?')) {
  const resetConfig = async () => {
    if (await showConfirm('Are you sure you want to reset all settings?')) {
      setLocalConfig(CONFIG_DEFAULT);
    }
  };

  const handleSave = () => {
  const handleSave = async () => {
    // copy the local config to prevent direct mutation
    const newConfig: typeof CONFIG_DEFAULT = JSON.parse(
      JSON.stringify(localConfig)
@@ -302,14 +304,14 @@ export default function SettingDialog({
      const mustBeNumeric = isNumeric(CONFIG_DEFAULT[key as SettKey]);
      if (mustBeString) {
        if (!isString(value)) {
          alert(`Value for ${key} must be string`);
          await showAlert(`Value for ${key} must be string`);
          return;
        }
      } else if (mustBeNumeric) {
        const trimmedValue = value.toString().trim();
        const numVal = Number(trimmedValue);
        if (isNaN(numVal) || !isNumeric(numVal) || trimmedValue.length === 0) {
          alert(`Value for ${key} must be numeric`);
          await showAlert(`Value for ${key} must be numeric`);
          return;
        }
        // force conversion to number
@@ -317,7 +319,7 @@ export default function SettingDialog({
        newConfig[key] = numVal;
      } else if (mustBeBoolean) {
        if (!isBoolean(value)) {
          alert(`Value for ${key} must be boolean`);
          await showAlert(`Value for ${key} must be boolean`);
          return;
        }
      } else {
@@ -14,6 +14,7 @@ import {
import { BtnWithTooltips } from '../utils/common';
import { useAppContext } from '../utils/app.context';
import toast from 'react-hot-toast';
import { useModals } from './ModalProvider';

export default function Sidebar() {
  const params = useParams();
@@ -38,6 +39,7 @@ export default function Sidebar() {
      StorageUtils.offConversationChanged(handleConversationChange);
    };
  }, []);
  const { showConfirm, showPrompt } = useModals();

  const groupedConv = useMemo(
    () => groupConversationsByDate(conversations),
@@ -130,7 +132,7 @@ export default function Sidebar() {
                  onSelect={() => {
                    navigate(`/chat/${conv.id}`);
                  }}
                  onDelete={() => {
                  onDelete={async () => {
                    if (isGenerating(conv.id)) {
                      toast.error(
                        'Cannot delete conversation while generating'
@@ -138,7 +140,7 @@ export default function Sidebar() {
                      return;
                    }
                    if (
                      window.confirm(
                      await showConfirm(
                        'Are you sure to delete this conversation?'
                      )
                    ) {
@@ -167,14 +169,14 @@ export default function Sidebar() {
                    document.body.removeChild(a);
                    URL.revokeObjectURL(url);
                  }}
                  onRename={() => {
                  onRename={async () => {
                    if (isGenerating(conv.id)) {
                      toast.error(
                        'Cannot rename conversation while generating'
                      );
                      return;
                    }
                    const newName = window.prompt(
                    const newName = await showPrompt(
                      'Enter new name for the conversation',
                      conv.name
                    );