mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-05-07 16:57:34 +03:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f0ab726f7 | ||
|
|
e82aaf2587 | ||
|
|
27aef3dd91 | ||
|
|
45155597aa | ||
|
|
80afa33aad | ||
|
|
b42c7fa5b8 | ||
|
|
d77599234e |
@@ -232,34 +232,6 @@ static struct llama_sampler * common_reasoning_budget_init_state(
|
||||
);
|
||||
}
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
const std::vector<llama_token> & prefill_tokens) {
|
||||
// Determine initial state from prefill: COUNTING if the prefill begins with
|
||||
// the start sequence but does not also contain the end sequence after it.
|
||||
common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE;
|
||||
if (!prefill_tokens.empty() && !start_tokens.empty() &&
|
||||
prefill_tokens.size() >= start_tokens.size() &&
|
||||
std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) {
|
||||
initial_state = REASONING_BUDGET_COUNTING;
|
||||
// If the end sequence also follows the start in the prefill, reasoning
|
||||
// was opened and immediately closed — stay IDLE.
|
||||
if (!end_tokens.empty() &&
|
||||
prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) {
|
||||
auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size();
|
||||
if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() &&
|
||||
std::equal(end_tokens.begin(), end_tokens.end(), end_start)) {
|
||||
initial_state = REASONING_BUDGET_IDLE;
|
||||
}
|
||||
}
|
||||
}
|
||||
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
|
||||
}
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
|
||||
@@ -29,10 +29,7 @@ enum common_reasoning_budget_state {
|
||||
// end_tokens - token sequence for natural deactivation
|
||||
// forced_tokens - token sequence forced when budget expires
|
||||
// budget - max tokens allowed in the reasoning block
|
||||
// prefill_tokens - tokens already present in the prompt (generation prompt);
|
||||
// used to determine the initial state: COUNTING if they begin
|
||||
// with start_tokens (but don't also end with end_tokens),
|
||||
// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
|
||||
// initial_state - initial state
|
||||
//
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
@@ -40,16 +37,6 @@ struct llama_sampler * common_reasoning_budget_init(
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
const std::vector<llama_token> & prefill_tokens = {});
|
||||
|
||||
// Variant that takes an explicit initial state (used by tests and clone).
|
||||
// COUNTING with budget <= 0 is promoted to FORCING.
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state);
|
||||
common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE);
|
||||
|
||||
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);
|
||||
|
||||
@@ -260,32 +260,35 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
}
|
||||
}
|
||||
|
||||
// Compute prefill tokens from the generation prompt
|
||||
std::vector<llama_token> prefill_tokens;
|
||||
if (!params.generation_prompt.empty()) {
|
||||
GGML_ASSERT(vocab != nullptr);
|
||||
auto tokens = common_tokenize(vocab, params.generation_prompt, false, true);
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
std::string piece = common_token_to_piece(vocab, tokens[i], true);
|
||||
if (i == 0 && std::isspace(piece[0]) && !std::isspace(params.generation_prompt[0])) {
|
||||
// Some tokenizers will add a space before the first special token, need to exclude
|
||||
continue;
|
||||
}
|
||||
LOG_DBG("%s: prefill token: %d = %s\n", __func__, tokens[i], piece.c_str());
|
||||
prefill_tokens.push_back(tokens[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Feed generation prompt tokens to the grammar sampler so it advances past
|
||||
// tokens the template already placed in the prompt.
|
||||
// Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled.
|
||||
std::vector<llama_token> prefill_tokens;
|
||||
if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) {
|
||||
GGML_ASSERT(vocab != nullptr);
|
||||
prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true);
|
||||
if (!prefill_tokens.empty()) {
|
||||
std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true);
|
||||
if (std::isspace(first_token[0]) && !std::isspace(params.generation_prompt[0])) {
|
||||
// Some tokenizers will add a space before the first special token, need to remove
|
||||
prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end());
|
||||
}
|
||||
}
|
||||
|
||||
if (grmr && !params.grammar_lazy) {
|
||||
try {
|
||||
for (const auto & token : prefill_tokens) {
|
||||
llama_sampler_accept(grmr, token);
|
||||
LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token);
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
|
||||
common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
|
||||
throw e;
|
||||
if (grmr && !params.grammar_lazy && common_grammar_needs_prefill(params.grammar)) {
|
||||
try {
|
||||
for (const auto & token : prefill_tokens) {
|
||||
llama_sampler_accept(grmr, token);
|
||||
LOG_DBG("%s: grammar accepted prefill token (%d)\n", __func__, token);
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
|
||||
common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -296,8 +299,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
params.reasoning_budget_start,
|
||||
params.reasoning_budget_end,
|
||||
params.reasoning_budget_forced,
|
||||
params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
|
||||
prefill_tokens);
|
||||
params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens);
|
||||
|
||||
for (const auto & token : prefill_tokens) {
|
||||
llama_sampler_accept(rbudget, token);
|
||||
LOG_DBG("%s: reasoning-budget accepted prefill token (%d)\n", __func__, token);
|
||||
}
|
||||
}
|
||||
|
||||
if (params.has_logit_bias()) {
|
||||
@@ -431,7 +438,7 @@ static bool grammar_should_apply(struct common_sampler * gsmpl) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool is_generated) {
|
||||
if (!gsmpl) {
|
||||
return;
|
||||
}
|
||||
@@ -439,9 +446,11 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
// grammar_should_apply() checks the reasoning budget state, so calculate this before we accept
|
||||
accept_grammar = accept_grammar && grammar_should_apply(gsmpl);
|
||||
const auto accept_grammar = is_generated && grammar_should_apply(gsmpl);
|
||||
|
||||
llama_sampler_accept(gsmpl->rbudget, token);
|
||||
if (gsmpl->rbudget && is_generated) {
|
||||
llama_sampler_accept(gsmpl->rbudget, token);
|
||||
}
|
||||
|
||||
if (gsmpl->grmr && accept_grammar) {
|
||||
llama_sampler_accept(gsmpl->grmr, token);
|
||||
|
||||
@@ -41,8 +41,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
|
||||
void common_sampler_free(struct common_sampler * gsmpl);
|
||||
|
||||
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
|
||||
// if is_generated is true, the token is accepted by the sampling chain, the reasoning budget sampler, and the grammar sampler
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool is_generated);
|
||||
void common_sampler_reset (struct common_sampler * gsmpl);
|
||||
struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
|
||||
|
||||
|
||||
@@ -167,8 +167,6 @@ struct common_speculative_checkpoint {
|
||||
size_t size() const {
|
||||
return data.size();
|
||||
}
|
||||
|
||||
size_t ckpt_size = 0;
|
||||
};
|
||||
|
||||
struct common_speculative_state_draft : public common_speculative_state {
|
||||
@@ -176,7 +174,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
llama_context * ctx_dft;
|
||||
|
||||
bool use_ckpt = false;
|
||||
struct common_speculative_checkpoint ckpt;
|
||||
common_speculative_checkpoint ckpt;
|
||||
|
||||
common_sampler * smpl;
|
||||
|
||||
@@ -249,26 +247,16 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
if (use_ckpt && ckpt.size() > 0) {
|
||||
// delete checkpoint
|
||||
LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
|
||||
__func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
|
||||
ckpt.pos_min = 0;
|
||||
ckpt.pos_max = 0;
|
||||
ckpt.n_tokens = 0;
|
||||
ckpt.ckpt_size = 0;
|
||||
ckpt.data.clear();
|
||||
}
|
||||
void begin(const llama_tokens & /*prompt*/) override {
|
||||
}
|
||||
|
||||
size_t draft_create_checkpoint(int n_tokens_prompt, int n_tokens_batch) {
|
||||
size_t create_checkpoint(int n_tokens_prompt) {
|
||||
int slot_id = 0;
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
|
||||
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
|
||||
ckpt.n_tokens = n_tokens_prompt - n_tokens_batch;
|
||||
ckpt.n_tokens = n_tokens_prompt;
|
||||
ckpt.data.resize(checkpoint_size);
|
||||
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
@@ -281,13 +269,13 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
return n;
|
||||
}
|
||||
|
||||
size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) {
|
||||
size_t restore_checkpoint() {
|
||||
int slot_id = 0;
|
||||
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
if (n != ckpt_size_part_expected) {
|
||||
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
|
||||
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
|
||||
if (n != ckpt.size()) {
|
||||
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu",
|
||||
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
||||
}
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);
|
||||
|
||||
@@ -346,13 +334,18 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
|
||||
const int i_start = std::max<int>(0, (int) prompt_cur.size() - n_ctx);
|
||||
|
||||
if (use_ckpt && i_start > 0) {
|
||||
LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
// reuse as much as possible from the old draft context
|
||||
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
|
||||
for (int i = 0; i < (int) prompt_dft.size(); ++i) {
|
||||
int cur = 0;
|
||||
while (i_start + cur < (int) prompt_cur.size() &&
|
||||
i + cur < (int) prompt_dft.size() &&
|
||||
prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
|
||||
i + cur < (int) prompt_dft.size() &&
|
||||
prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
|
||||
cur++;
|
||||
}
|
||||
|
||||
@@ -360,21 +353,26 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
reuse_i = i;
|
||||
reuse_n = cur;
|
||||
}
|
||||
|
||||
if (use_ckpt) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
|
||||
__func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
|
||||
if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) {
|
||||
LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
|
||||
__func__, reuse_i, reuse_n);
|
||||
if (use_ckpt && ckpt.n_tokens > reuse_n) {
|
||||
LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens);
|
||||
|
||||
reuse_i = 0;
|
||||
reuse_n = 0;
|
||||
|
||||
ckpt = {};
|
||||
}
|
||||
|
||||
result.clear();
|
||||
result.reserve(sparams.n_max);
|
||||
|
||||
bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
|
||||
if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
|
||||
llama_memory_clear(mem_dft, false);
|
||||
prompt_dft.clear();
|
||||
@@ -393,50 +391,38 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
return;
|
||||
}
|
||||
|
||||
bool do_restore = false;
|
||||
if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) {
|
||||
// This can happen after a partial acceptance (speculative decoding with checkpoints)
|
||||
LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n",
|
||||
__func__, prompt_dft.size(), prompt_cur.size());
|
||||
prompt_dft.resize(prompt_cur.size());
|
||||
do_restore = true;
|
||||
}
|
||||
|
||||
if (reuse_i > 0) {
|
||||
GGML_ASSERT(!use_ckpt);
|
||||
|
||||
bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
|
||||
if (!is_removed) {
|
||||
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
|
||||
return;
|
||||
}
|
||||
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
|
||||
|
||||
prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
|
||||
}
|
||||
|
||||
if (reuse_n < (int) prompt_dft.size() || do_restore) {
|
||||
if (reuse_n < (int) prompt_dft.size()) {
|
||||
if (use_ckpt) {
|
||||
if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
|
||||
LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
|
||||
__func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
|
||||
if (ckpt.n_tokens > 0) {
|
||||
LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
|
||||
restore_checkpoint();
|
||||
reuse_n = ckpt.n_tokens;
|
||||
prompt_dft.resize(reuse_n);
|
||||
}
|
||||
draft_restore_checkpoint(ckpt.ckpt_size);
|
||||
reuse_n = ckpt.n_tokens;
|
||||
prompt_dft.resize(reuse_n);
|
||||
needs_ckpt = false;
|
||||
} else {
|
||||
bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
|
||||
const bool is_removed = llama_memory_seq_rm(mem_dft, 0, reuse_n, -1);
|
||||
if (!is_removed) {
|
||||
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n",
|
||||
__func__, reuse_n, prompt_dft.size());
|
||||
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
|
||||
return;
|
||||
}
|
||||
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (needs_ckpt) {
|
||||
ckpt.ckpt_size = draft_create_checkpoint(prompt_dft.size(), batch.n_tokens);
|
||||
}
|
||||
|
||||
// prepare a batch to evaluate any new tokens in the prompt
|
||||
common_batch_clear(batch);
|
||||
|
||||
@@ -450,12 +436,17 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
// we should rarely end-up here during normal decoding
|
||||
if (batch.n_tokens > 0) {
|
||||
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
|
||||
LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens);
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0 && ret != 1) {
|
||||
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n",
|
||||
__func__, ret, prompt_cur.size());
|
||||
}
|
||||
|
||||
if (use_ckpt) {
|
||||
create_checkpoint(prompt_dft.size());
|
||||
}
|
||||
}
|
||||
|
||||
const llama_pos n_past = prompt_dft.size();
|
||||
@@ -784,17 +775,15 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
if (verbose) {
|
||||
LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last);
|
||||
}
|
||||
|
||||
// compute acceptance fraction if we have a recorded draft length
|
||||
if (n_draft_last > 0) {
|
||||
const double f_acc = (double)n_accepted / (double)n_draft_last;
|
||||
if (f_acc < 0.5) {
|
||||
n_low++;
|
||||
if (n_low >= 3) {
|
||||
LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
|
||||
if (verbose) {
|
||||
LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
|
||||
}
|
||||
|
||||
mod.reset();
|
||||
n_low = 0;
|
||||
|
||||
@@ -110,13 +110,21 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (
|
||||
llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
|
||||
llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
|
||||
llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
|
||||
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
|
||||
) {
|
||||
LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
|
||||
if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
|
||||
(llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) {
|
||||
LOG_ERR("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
__func__,
|
||||
llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft),
|
||||
llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft));
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
|
||||
(llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) {
|
||||
LOG_ERR("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
__func__,
|
||||
llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft),
|
||||
llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft));
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -137,11 +145,12 @@ int main(int argc, char ** argv) {
|
||||
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
||||
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
|
||||
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
||||
|
||||
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||
LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
|
||||
LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
|
||||
common_token_to_piece(ctx_tgt, i).c_str(),
|
||||
common_token_to_piece(ctx_dft, i).c_str());
|
||||
common_token_to_piece(vocab_tgt, i).c_str(),
|
||||
common_token_to_piece(vocab_dft, i).c_str());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
@@ -130,7 +130,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
|
||||
@@ -1124,7 +1124,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
|
||||
#ifdef GGML_USE_HIP
|
||||
if constexpr (DV <= 128) {
|
||||
if constexpr (DKQ <= 128) {
|
||||
if (Q->ne[1] > 32/ncols2) {
|
||||
constexpr int cols_per_block = 64;
|
||||
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
@@ -1138,7 +1138,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
|
||||
#endif // GGML_USE_HIP
|
||||
|
||||
#ifndef GGML_USE_HIP
|
||||
if constexpr (DV <= 256)
|
||||
if constexpr (DKQ <= 256)
|
||||
#endif // GGML_USE_HIP
|
||||
{
|
||||
if (Q->ne[1] > 16/ncols2) {
|
||||
@@ -1220,11 +1220,22 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
if constexpr (DKQ == 320) { // Mistral Small 4
|
||||
if constexpr (DKQ == 320) {
|
||||
// This branch is only used for Mistral Small 4 which has a GQA ratio of 32.
|
||||
// On AMD, simply use that GQA ratio with 32 columns / block since we always have enough SRAM.
|
||||
// On NVIDIA however, the tile kernel is only used for GPUs that can't use the mma kernel (Pascal and older).
|
||||
// Therefore, use a GQA ratio of 16 with 16 columns / block to stay below 48 kiB of SRAM / block.
|
||||
#ifdef GGML_USE_HIP
|
||||
if (use_gqa_opt && gqa_ratio % 32 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 32, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
#else
|
||||
if (use_gqa_opt && gqa_ratio % 16 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32");
|
||||
}
|
||||
|
||||
|
||||
@@ -1806,6 +1806,25 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
defines.push_back("SRC0_INNER_TYPE=u32");
|
||||
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
variant += std::string("_") + src0_name;
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -1422,7 +1422,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
use_fast = is_vec;
|
||||
use_fast = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
|
||||
@@ -740,3 +740,426 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q6_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ4_NL
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 18u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
|
||||
let pos = k_in_block % 16u;
|
||||
let nib_shift = (k_in_block / 16u) * 4u;
|
||||
let q_packed = load_u32_at_src0(block_byte_base + 2u + (pos / 4u) * 4u);
|
||||
let nib = (get_byte(q_packed, pos % 4u) >> nib_shift) & 0xFu;
|
||||
|
||||
shmem[elem_idx] = d * f16(kvalues_iq4nl[nib]);
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_IQ4_NL
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ4_XS
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 136u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d_scales_h = load_u32_at_src0(block_byte_base);
|
||||
let d = bitcast<vec2<f16>>(d_scales_h).x;
|
||||
let scales_h = d_scales_h >> 16u;
|
||||
|
||||
let ib = k_in_block / 32u;
|
||||
let pos = k_in_block % 32u;
|
||||
|
||||
let scales_l_word = load_u32_at_src0(block_byte_base + 4u);
|
||||
let ls_lo = (get_byte(scales_l_word, ib / 2u) >> ((ib & 1u) * 4u)) & 0xFu;
|
||||
let ls_hi = ((scales_h >> (2u * ib)) & 3u) << 4u;
|
||||
let dl = d * f16(i32(ls_lo | ls_hi) - 32);
|
||||
|
||||
let iqs = ib * 16u + (pos % 16u);
|
||||
let nib_shift = (pos / 16u) * 4u;
|
||||
let q_packed = load_u32_at_src0(block_byte_base + 8u + (iqs / 4u) * 4u);
|
||||
let nib = (get_byte(q_packed, iqs % 4u) >> nib_shift) & 0xFu;
|
||||
|
||||
shmem[elem_idx] = dl * f16(kvalues_iq4nl[nib]);
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_IQ4_XS
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ1_S
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 50u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
|
||||
let ib = k_in_block / 32u;
|
||||
let pos = k_in_block % 32u;
|
||||
let l = pos / 8u;
|
||||
let j = pos % 8u;
|
||||
|
||||
let qh = load_u32_at_src0(block_byte_base + 34u + ib * 2u) & 0xFFFFu;
|
||||
let dl = d * (2.0 * f32((qh >> 12u) & 7u) + 1.0);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u);
|
||||
|
||||
let qs_w = load_u32_at_src0(block_byte_base + 2u + ib * 4u);
|
||||
let ig = (get_byte(qs_w, l) | (((qh >> (3u * l)) & 7u) << 8u)) * 8u;
|
||||
|
||||
let gw = iq1_grid[(ig + j) / 16u];
|
||||
let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u;
|
||||
let gs = bitcast<i32>(g << 30u) >> 30u;
|
||||
|
||||
shmem[elem_idx] = f16(dl * (f32(gs) + delta));
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_IQ1_S
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ1_M
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 56u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let scales0 = load_u32_at_src0(block_byte_base + 48u);
|
||||
let scales1 = load_u32_at_src0(block_byte_base + 52u);
|
||||
let scale_packed = ((scales0 >> 12u) & 0xFu) |
|
||||
((scales0 >> 24u) & 0x00F0u) |
|
||||
((scales1 >> 4u) & 0x0F00u) |
|
||||
((scales1 >> 16u) & 0xF000u);
|
||||
let d = f32(bitcast<vec2<f16>>(scale_packed).x);
|
||||
|
||||
let ib = k_in_block / 32u;
|
||||
let pos = k_in_block % 32u;
|
||||
let l = pos / 8u;
|
||||
let j = pos % 8u;
|
||||
|
||||
let scales = select(scales0, scales1, ib >= 4u);
|
||||
let sw = (scales >> (16u * ((ib / 2u) % 2u))) & 0xFFFFu;
|
||||
let s_pair = (sw >> (6u * (ib % 2u) + 3u * (l / 2u))) & 0x7u;
|
||||
let dl = d * f32(2u * s_pair + 1u);
|
||||
|
||||
let qh_word = load_u32_at_src0(block_byte_base + 32u + (ib / 2u) * 4u);
|
||||
let qh = qh_word >> (16u * (ib % 2u));
|
||||
let qh_nib = (qh >> (4u * l)) & 0xFu;
|
||||
|
||||
let qs_w = load_u32_at_src0(block_byte_base + ib * 4u);
|
||||
let idx = get_byte(qs_w, l) | ((qh_nib & 7u) << 8u);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_nib & 0x8u) != 0u);
|
||||
|
||||
let ig = idx * 8u;
|
||||
let gw = iq1_grid[(ig + j) / 16u];
|
||||
let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u;
|
||||
let gs = bitcast<i32>(g << 30u) >> 30u;
|
||||
|
||||
shmem[elem_idx] = f16(dl * (f32(gs) + delta));
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_IQ1_M
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ2_XXS
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 66u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
|
||||
let entry_idx = k_in_block / 8u;
|
||||
let j = k_in_block % 8u;
|
||||
|
||||
let ib = entry_idx & ~3u;
|
||||
let l = entry_idx & 3u;
|
||||
|
||||
let aux0 = load_u32_at_src0(block_byte_base + 2u + ib * 2u);
|
||||
let aux1 = load_u32_at_src0(block_byte_base + 2u + (ib + 2u) * 2u);
|
||||
let db = d * (0.5 + f32(aux1 >> 28u)) * 0.25;
|
||||
|
||||
let ig = get_byte(aux0, l) * 8u;
|
||||
let is = (aux1 >> (7u * l)) & 127u;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
|
||||
|
||||
let g = get_byte(iq2xxs_grid[(ig + j) / 4u], (ig + j) % 4u);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
|
||||
|
||||
shmem[elem_idx] = f16(db * f32(g) * m);
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_IQ2_XXS
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ2_XS
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 74u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
|
||||
let entry_idx = k_in_block / 8u;
|
||||
let j = k_in_block % 8u;
|
||||
|
||||
let ib = entry_idx & ~3u;
|
||||
let l = entry_idx & 3u;
|
||||
|
||||
let scales_word = load_u32_at_src0(block_byte_base + 66u + (ib / 16u) * 4u);
|
||||
let s = get_byte(scales_word, (ib % 16u) / 4u);
|
||||
let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u);
|
||||
let dl = d * (0.5 + f32(s_nib)) * 0.25;
|
||||
|
||||
let qs_word = load_u32_at_src0(block_byte_base + 2u + (ib + l) * 2u);
|
||||
let qs_val = qs_word & 0xFFFFu;
|
||||
let ig = (qs_val & 511u) * 8u;
|
||||
let is = qs_val >> 9u;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
|
||||
|
||||
let g = get_byte(iq2xs_grid[(ig + j) / 4u], (ig + j) % 4u);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
|
||||
|
||||
shmem[elem_idx] = f16(dl * f32(g) * m);
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_IQ2_XS
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ2_S
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 82u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
|
||||
let ib = k_in_block / 32u;
|
||||
let l = (k_in_block % 32u) / 8u;
|
||||
let j = k_in_block % 8u;
|
||||
|
||||
let scales_word = load_u32_at_src0(block_byte_base + 74u + (ib / 4u) * 4u);
|
||||
let s = get_byte(scales_word, ib % 4u);
|
||||
let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u);
|
||||
let dl = d * (0.5 + f32(s_nib)) * 0.25;
|
||||
|
||||
let qs_word = load_u32_at_src0(block_byte_base + 2u + ib * 4u);
|
||||
let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 4u) * 4u);
|
||||
let qh_b = (get_byte(qh_word, ib % 4u) << (8u - 2u * l)) & 0x300u;
|
||||
let ig = (get_byte(qs_word, l) | qh_b) * 8u;
|
||||
|
||||
let signs_word = load_u32_at_src0(block_byte_base + 34u + ib * 4u);
|
||||
let signs = get_byte(signs_word, l);
|
||||
|
||||
let g = get_byte(iq2s_grid[(ig + j) / 4u], (ig + j) % 4u);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
|
||||
|
||||
shmem[elem_idx] = f16(dl * f32(g) * m);
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_IQ2_S
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ3_XXS
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 98u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
|
||||
let ib_pair = k_in_block / 32u;
|
||||
let in_pair = k_in_block % 32u;
|
||||
let l = in_pair / 8u;
|
||||
let in_l = in_pair % 8u;
|
||||
let k2 = in_l / 4u;
|
||||
let j = in_l % 4u;
|
||||
|
||||
let ib = ib_pair * 2u;
|
||||
let sc_sign_off = block_byte_base + 2u + (ib + 32u) * 2u;
|
||||
let sc_sign = load_u32_at_src0(sc_sign_off);
|
||||
let db = d * (0.5 + f32(sc_sign >> 28u)) * 0.5;
|
||||
let is = (sc_sign >> (7u * l)) & 127u;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
|
||||
|
||||
let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 2u + l) * 2u) & 0xFFFFu;
|
||||
let ig_byte = get_byte(ig_word, k2);
|
||||
let g = get_byte(iq3xxs_grid[ig_byte], j);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u);
|
||||
|
||||
shmem[elem_idx] = f16(db * f32(g) * m);
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_IQ3_XXS
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ3_S
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 110u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
|
||||
let ib = k_in_block / 64u;
|
||||
let rest = k_in_block % 64u;
|
||||
let k = rest / 32u;
|
||||
let in_k = rest % 32u;
|
||||
let l = in_k / 8u;
|
||||
let in_l = in_k % 8u;
|
||||
let k2 = in_l / 4u;
|
||||
let j = in_l % 4u;
|
||||
|
||||
let scales_word = load_u32_at_src0(block_byte_base + 106u);
|
||||
let s = get_byte(scales_word, ib);
|
||||
let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, k != 0u);
|
||||
let dl = d * (1.0 + 2.0 * f32(s_nib));
|
||||
|
||||
let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 2u) * 4u);
|
||||
let qh_byte = get_byte(qh_word, (ib % 2u) * 2u + k);
|
||||
|
||||
let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 8u + k * 4u + l) * 2u) & 0xFFFFu;
|
||||
let ig_lo = get_byte(ig_word, 0u) | ((qh_byte << (8u - 2u * l)) & 256u);
|
||||
let ig_hi = get_byte(ig_word, 1u) | ((qh_byte << (7u - 2u * l)) & 256u);
|
||||
let ig = select(ig_lo, ig_hi, k2 != 0u);
|
||||
|
||||
let signs_word = load_u32_at_src0(block_byte_base + 74u + (ib * 2u + k) * 4u);
|
||||
let signs = get_byte(signs_word, l);
|
||||
|
||||
let g = get_byte(iq3s_grid[ig], j);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u);
|
||||
|
||||
shmem[elem_idx] = f16(dl * f32(g) * m);
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_IQ3_S
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
HTTPLIB_VERSION = "refs/tags/v0.43.1"
|
||||
HTTPLIB_VERSION = "refs/tags/v0.43.2"
|
||||
|
||||
vendor = {
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
|
||||
|
||||
58
scripts/wc2wt.sh
Executable file
58
scripts/wc2wt.sh
Executable file
@@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# initialize a new worktree from a branch name:
|
||||
#
|
||||
# - creates a new branch from current HEAD
|
||||
# - creates a new worktree in a parent folder, suffixed with the branch name
|
||||
#
|
||||
# sample usage:
|
||||
# ./scripts/wc2wt.sh gg/new-feature-foo-bar
|
||||
# ./scripts/wc2wt.sh gg/new-feature-foo-bar opencode
|
||||
# ./scripts/wc2wt.sh gg/new-feature-foo-bar "cmake -B build && cmake --build build"
|
||||
# ./scripts/wc2wt.sh gg/new-feature-foo-bar "bash -l"
|
||||
|
||||
function usage() {
|
||||
echo "usage: $0 <branch_name> [cmd]"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# check we are in the right directory
|
||||
if [[ ! -f "scripts/wc2wt.sh" ]]; then
|
||||
echo "error: this script must be run from the root of the repository"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ $# -lt 1 || $# -gt 2 ]]; then
|
||||
usage
|
||||
fi
|
||||
|
||||
BRANCH=$1
|
||||
|
||||
if [[ -z "$BRANCH" ]]; then
|
||||
echo "error: branch name must not be empty"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
dir=$(basename $(pwd))
|
||||
# sanitize branch name for directory name (replace / with -)
|
||||
dir_suffix=$(echo "$BRANCH" | tr '/' '-')
|
||||
|
||||
git worktree add -b "$BRANCH" "../$dir-$dir_suffix" HEAD
|
||||
|
||||
og_path=$(pwd)
|
||||
wt_path=$(cd "../$dir-$dir_suffix" && pwd)
|
||||
|
||||
echo "git worktree created in $wt_path"
|
||||
|
||||
cd "$wt_path"
|
||||
|
||||
# pi agent setup in the worktree
|
||||
if [[ -f "$og_path/.pi/SYSTEM.md" && ! -f ".pi/SYSTEM.md" ]]; then
|
||||
mkdir -p .pi
|
||||
ln -sfn "$og_path/.pi/SYSTEM.md" .pi/SYSTEM.md
|
||||
fi
|
||||
|
||||
if [[ $# -eq 2 ]]; then
|
||||
echo "executing: $2"
|
||||
eval "$2"
|
||||
fi
|
||||
@@ -680,6 +680,7 @@ private:
|
||||
// slots / clients
|
||||
std::vector<server_slot> slots;
|
||||
|
||||
int trace = 0;
|
||||
int slots_debug = 0;
|
||||
int n_empty_consecutive = 0;
|
||||
|
||||
@@ -918,12 +919,21 @@ private:
|
||||
slot.reset();
|
||||
}
|
||||
|
||||
{
|
||||
const char * LLAMA_TRACE = getenv("LLAMA_TRACE");
|
||||
trace = LLAMA_TRACE ? atoi(LLAMA_TRACE) : 0;
|
||||
|
||||
if (trace) {
|
||||
SRV_WRN("LLAMA_TRACE = %d\n", trace);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG");
|
||||
slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0;
|
||||
|
||||
if (slots_debug) {
|
||||
SRV_WRN("slots debug = %d\n", slots_debug);
|
||||
SRV_WRN("LLAMA_SERVER_SLOTS_DEBUG = %d\n", slots_debug);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2974,13 +2984,15 @@ private:
|
||||
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft);
|
||||
slot.spec_i_batch.clear();
|
||||
|
||||
SLT_DBG(slot, "%s: n_draft=%zu, accepted=%zu\n", __func__, slot.spec_draft.size(), accepted.size());
|
||||
|
||||
GGML_ASSERT(accepted.size() >= 1);
|
||||
|
||||
// check for partial draft acceptance
|
||||
if (accepted.size() < slot.spec_draft.size() + 1) {
|
||||
if (use_ckpt) {
|
||||
if (trace > 0) {
|
||||
SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
|
||||
}
|
||||
|
||||
// partial acceptance is not supported by the context -> truncate the draft and restore the state
|
||||
slot.spec_draft = std::move(accepted);
|
||||
|
||||
@@ -3002,8 +3014,10 @@ private:
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_DBG("%s: partial acceptance: %zu < %zu\n", __func__, accepted.size(), slot.spec_draft.size());
|
||||
if (trace > 0) {
|
||||
SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft);
|
||||
}
|
||||
|
||||
common_speculative_accept(slot.spec.get(), accepted.size() - 1);
|
||||
|
||||
68
vendor/cpp-httplib/httplib.cpp
vendored
68
vendor/cpp-httplib/httplib.cpp
vendored
@@ -1464,8 +1464,9 @@ bool mmap::open(const char *path) {
|
||||
auto wpath = u8string_to_wstring(path);
|
||||
if (wpath.empty()) { return false; }
|
||||
|
||||
hFile_ = ::CreateFile2(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ,
|
||||
OPEN_EXISTING, NULL);
|
||||
hFile_ =
|
||||
::CreateFile2(wpath.c_str(), GENERIC_READ,
|
||||
FILE_SHARE_READ | FILE_SHARE_WRITE, OPEN_EXISTING, NULL);
|
||||
|
||||
if (hFile_ == INVALID_HANDLE_VALUE) { return false; }
|
||||
|
||||
@@ -2052,56 +2053,50 @@ int getaddrinfo_with_timeout(const char *node, const char *service,
|
||||
return 0;
|
||||
#elif defined(_GNU_SOURCE) && defined(__GLIBC__) && \
|
||||
(__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 2))
|
||||
// Linux implementation using getaddrinfo_a for asynchronous DNS resolution
|
||||
struct gaicb request;
|
||||
// #2431: gai_cancel() is non-blocking and may return EAI_NOTCANCELED while
|
||||
// the resolver worker still references the stack-local gaicb. The cancel
|
||||
// path therefore waits (gai_suspend with no timeout) for the worker to
|
||||
// actually finish before letting the stack frame go. The trade-off is that
|
||||
// a wedged DNS server can hold this thread for the system resolver timeout
|
||||
// (~30s by default) past the caller's connection timeout.
|
||||
struct gaicb request {};
|
||||
struct gaicb *requests[1] = {&request};
|
||||
struct sigevent sevp;
|
||||
struct timespec timeout;
|
||||
struct sigevent sevp {};
|
||||
struct timespec timeout {
|
||||
timeout_sec, 0
|
||||
};
|
||||
|
||||
// Initialize the request structure
|
||||
memset(&request, 0, sizeof(request));
|
||||
request.ar_name = node;
|
||||
request.ar_service = service;
|
||||
request.ar_request = hints;
|
||||
|
||||
// Set up timeout
|
||||
timeout.tv_sec = timeout_sec;
|
||||
timeout.tv_nsec = 0;
|
||||
|
||||
// Initialize sigevent structure (not used, but required)
|
||||
memset(&sevp, 0, sizeof(sevp));
|
||||
sevp.sigev_notify = SIGEV_NONE;
|
||||
|
||||
// Start asynchronous resolution
|
||||
int start_result = getaddrinfo_a(GAI_NOWAIT, requests, 1, &sevp);
|
||||
if (start_result != 0) { return start_result; }
|
||||
int rc = getaddrinfo_a(GAI_NOWAIT, requests, 1, &sevp);
|
||||
if (rc != 0) { return rc; }
|
||||
|
||||
// Wait for completion with timeout
|
||||
int wait_result =
|
||||
gai_suspend((const struct gaicb *const *)requests, 1, &timeout);
|
||||
auto cleanup = scope_exit([&] {
|
||||
if (request.ar_result) { freeaddrinfo(request.ar_result); }
|
||||
});
|
||||
|
||||
int wait_result = gai_suspend(requests, 1, &timeout);
|
||||
|
||||
if (wait_result == 0 || wait_result == EAI_ALLDONE) {
|
||||
// Completed successfully, get the result
|
||||
int gai_result = gai_error(&request);
|
||||
if (gai_result == 0) {
|
||||
*res = request.ar_result;
|
||||
request.ar_result = nullptr;
|
||||
return 0;
|
||||
} else {
|
||||
// Clean up on error
|
||||
if (request.ar_result) { freeaddrinfo(request.ar_result); }
|
||||
return gai_result;
|
||||
}
|
||||
} else if (wait_result == EAI_AGAIN) {
|
||||
// Timeout occurred, cancel the request
|
||||
gai_cancel(&request);
|
||||
return EAI_AGAIN;
|
||||
} else {
|
||||
// Other error occurred
|
||||
gai_cancel(&request);
|
||||
return wait_result;
|
||||
return gai_result;
|
||||
}
|
||||
|
||||
gai_cancel(&request);
|
||||
while (gai_error(&request) == EAI_INPROGRESS) {
|
||||
gai_suspend(requests, 1, nullptr);
|
||||
}
|
||||
return wait_result;
|
||||
#else
|
||||
// Fallback implementation using thread-based timeout for other Unix systems
|
||||
// Fallback implementation using thread-based timeout for other Unix systems.
|
||||
|
||||
struct GetAddrInfoState {
|
||||
~GetAddrInfoState() {
|
||||
@@ -14142,6 +14137,9 @@ ssize_t read(session_t session, void *buf, size_t len, TlsError &err) {
|
||||
err.code = impl::map_mbedtls_error(ret, err.sys_errno);
|
||||
err.backend_code = static_cast<uint64_t>(-ret);
|
||||
impl::mbedtls_last_error() = ret;
|
||||
// mbedTLS signals a clean close_notify via a negative error code rather
|
||||
// than 0; surface it as a clean EOF the way OpenSSL/wolfSSL do.
|
||||
if (err.code == ErrorCode::PeerClosed) { return 0; }
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
4
vendor/cpp-httplib/httplib.h
vendored
4
vendor/cpp-httplib/httplib.h
vendored
@@ -8,8 +8,8 @@
|
||||
#ifndef CPPHTTPLIB_HTTPLIB_H
|
||||
#define CPPHTTPLIB_HTTPLIB_H
|
||||
|
||||
#define CPPHTTPLIB_VERSION "0.43.1"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002b01"
|
||||
#define CPPHTTPLIB_VERSION "0.43.2"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002b02"
|
||||
|
||||
#ifdef _WIN32
|
||||
#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00
|
||||
|
||||
Reference in New Issue
Block a user