experiments

Georgi Gerganov
2026-01-28 09:45:07 +02:00
parent 003c90352d
commit 6c8a04576e
6 changed files with 189 additions and 17 deletions

View File

@@ -3398,7 +3398,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
    {"--spec-draftless"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-map-mod]",
    string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
        common_speculative_type_to_str(params.speculative.type).c_str()),
    [](common_params & params, const std::string & value) {
@@ -3412,6 +3412,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
        } else if (value == "ngram-map-k4v") {
            params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
        } else if (value == "ngram-map-mod") {
            params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD;
        } else {
            throw std::invalid_argument("unknown speculative decoding type without draft model");
        }

View File

@@ -171,6 +171,7 @@ enum common_speculative_type {
    COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,  // simple self-speculative decoding
    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,   // self-speculative decoding with n-gram keys only
    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD, // self-speculative decoding with a fixed-size modular-hash n-gram table
    COMMON_SPECULATIVE_TYPE_NGRAM_CACHE,   // self-speculative decoding with 3-level n-gram cache
    COMMON_SPECULATIVE_TYPE_COUNT          // number of types, unknown type
};

View File

@@ -7,6 +7,21 @@
#include <cstdio>
#include <sstream>

// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
    std::ostringstream oss;
    oss << '[';
    for (size_t i = 0; i < length; ++i) {
        if (i > 0) {
            oss << ", ";
        }
        oss << inp[start + i];
    }
    oss << ']';
    return oss.str();
}
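As a quick illustration of the helper (the token ids here are made up, not from the commit):

// prints "[1917, 11]" — the middle two tokens of a 4-token window
llama_tokens toks = { 15339, 1917, 11, 4435 };
printf("%s\n", common_tokens_to_str(toks, 1, 2).c_str());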
// n-gram simple
//
@@ -100,8 +115,6 @@ llama_tokens common_ngram_simple_draft(
// maximum number of counted values of a ngram map value.
#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
void common_ngram_map_draft(common_ngram_map & map,
        const llama_tokens & inp, llama_token sampled,
        llama_tokens & draft) {
@@ -348,20 +361,97 @@ void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
    curr_value.n_accepted = n_accepted;
}
//
// n-gram mod
//

common_ngram_mod::common_ngram_mod(uint16_t m) : m(m) {
    int64_t n = 1;
    for (int32_t i = 0; i < N_MODS; ++i) {
        n *= mods[i];
    }

    entries.resize(n);

    const size_t size_bytes = entries.size() * sizeof(common_ngram_mod_entry);

    LOG_INF("%s: size = %.3f MB\n", __func__, size_bytes / (1024.0 * 1024.0));
}
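To put numbers on that log line: with the mods table declared in the header (the entries 2, 8, 16, 32 and 64; every other position is 1), n = 2 * 8 * 16 * 32 * 64 = 524288 entries, and sizeof(common_ngram_mod_entry) = 4 + 4 + 4 * 4 = 24 bytes, so the table comes to 524288 * 24 bytes = exactly 12 MB.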
// Record the continuation token `tokens[N_MODS]` for the n-gram `tokens[0..N_MODS)`.
// Each entry is a small ring buffer: once COMMON_NGRAM_MOD_MAX_CHOICES continuations
// have been stored, the oldest one is overwritten.
void common_ngram_mod::add(const llama_token * tokens) {
    const uint64_t i = idx(tokens);

    common_ngram_mod_entry & entry = entries[i];

    if (entry.n_choices < COMMON_NGRAM_MOD_MAX_CHOICES) {
        entry.n_choices++;
    }

    entry.choices[entry.head] = tokens[N_MODS];
    entry.head = (entry.head + 1) % COMMON_NGRAM_MOD_MAX_CHOICES;
}
// Look up a stored continuation for the n-gram `tokens[0..N_MODS)`.
// `offs` rotates deterministically through the stored choices;
// returns LLAMA_TOKEN_NULL when the entry has never been filled.
llama_token common_ngram_mod::get(const llama_token * tokens, int32_t offs) const {
    const uint64_t i = idx(tokens);

    const common_ngram_mod_entry & entry = entries[i];

    if (entry.n_choices == 0) {
        return LLAMA_TOKEN_NULL;
    }

    const int32_t k = (offs + entry.head) % entry.n_choices;

    return entry.choices[k];
}
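A minimal standalone sketch of the ring-buffer behavior above (toy types, not the actual llama.cpp ones): after more than COMMON_NGRAM_MOD_MAX_CHOICES insertions the oldest continuation is overwritten, and get()-style indexing walks the survivors oldest-first.

#include <cstdint>
#include <cstdio>

int main() {
    uint32_t head = 0, n_choices = 0;
    int32_t  choices[4];                     // mirrors COMMON_NGRAM_MOD_MAX_CHOICES = 4

    for (int32_t t = 100; t < 106; ++t) {    // six insertions into a 4-slot ring
        if (n_choices < 4) n_choices++;
        choices[head] = t;
        head = (head + 1) % 4;
    }

    // 100 and 101 have been overwritten; this prints: 102 103 104 105
    for (uint32_t offs = 0; offs < n_choices; ++offs) {
        printf("%d ", choices[(offs + head) % n_choices]);
    }
    printf("\n");
}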
// Map an n-gram of N_MODS tokens to a table index. A rolling hash `rh` is folded
// position by position, with `rh % mods[i]` appended as a mixed-radix digit.
// Positions with mods[i] == 1 append nothing but still feed the rolling hash,
// so the result is always < prod(mods) == entries.size().
uint64_t common_ngram_mod::idx(const llama_token * tokens) {
    uint64_t rh = 0;
    uint64_t res = 0;

    for (uint64_t i = 0; i < N_MODS; ++i) {
        rh  = rh * 31 + tokens[i];
        res = res * mods[i] + (rh % mods[i]);
    }

    return res;
}
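A hypothetical standalone check of that range invariant (this test is not part of the commit; it just recomputes the hash over random stand-in token ids):

#include <cassert>
#include <cstdint>
#include <random>

int main() {
    const int32_t mods[17] = { 2, 1, 1, 1, 8, 1, 1, 1, 16, 1, 1, 1, 32, 1, 1, 1, 64 };
    std::mt19937 rng(42);

    for (int iter = 0; iter < 100000; ++iter) {
        uint64_t rh = 0, res = 0;
        for (int i = 0; i < 17; ++i) {
            rh  = rh * 31 + rng() % 32000;   // stand-in for a token id
            res = res * mods[i] + rh % mods[i];
        }
        assert(res < 2u * 8 * 16 * 32 * 64); // 524288 == entries.size()
    }
    return 0;
}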
void common_ngram_mod_draft(
        common_ngram_mod & mod,
        const llama_tokens & inp,
        llama_token sampled,
        llama_tokens & draft) {
    const size_t N_MODS  = common_ngram_mod::N_MODS;
    const size_t cur_len = inp.size();

    if (cur_len < N_MODS) {
        return;
    }

    // every 64 calls, fold a 256-token window of the context into the table,
    // cycling the window start through the (padded) context
    if (mod.n_calls++ % 64 == 0) {
        const size_t n_start = (256*(mod.n_calls/64)) % GGML_PAD(cur_len, 256);
        for (size_t i = 0; i < 256 && n_start + i < cur_len - N_MODS; ++i) {
            mod.add(inp.data() + n_start + i);
        }
    }

    // seed the draft with the last N_MODS - 1 context tokens plus the sampled token
    draft.resize(N_MODS + mod.m);
    for (size_t i = 0; i < N_MODS - 1; ++i) {
        draft[i] = inp[cur_len - N_MODS + 1 + i];
    }
    draft[N_MODS - 1] = sampled;

    // extend the draft one token at a time, re-querying the table with the
    // sliding N_MODS-token window; abort if any lookup misses
    for (size_t i = 0; i < mod.m; ++i) {
        const llama_token token = mod.get(draft.data() + i, cur_len + i);
        if (token == LLAMA_TOKEN_NULL) {
            draft.clear();
            return;
        }
        draft[N_MODS + i] = token;
    }

    // only return the m tokens that were drafted
    for (size_t i = 0; i < mod.m; ++i) {
        draft[i] = draft[N_MODS + i];
    }
    draft.resize(mod.m);
}
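For orientation, a rough sketch of how a caller could drive this function (the loop and variables are made up; the real call site is common_speculative_gen_draft in common/speculative.cpp below):

// toy repetitive history so the table lookups have a chance to hit
common_ngram_mod mod(/*m =*/ 8);   // draft up to 8 tokens per call

llama_tokens ctx;
for (int i = 0; i < 1024; ++i) {
    ctx.push_back(i % 50);         // fake periodic token stream
}
llama_token sampled = 1024 % 50;   // pretend this token was just sampled

llama_tokens draft;
common_ngram_mod_draft(mod, ctx, sampled, draft);

if (!draft.empty()) {
    // feed `draft` to the target model for verification; accepted tokens
    // extend the context, the rest are discarded
}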

View File

@@ -11,6 +11,7 @@
//
#include "llama.h"
#include "common.h"
#include <vector>
@@ -103,3 +104,40 @@ void common_ngram_map_draft(
// Update the statistics of a value after a draft was processed.
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted);
//
// n-gram mod
//

// maximum number of continuation tokens remembered per n-gram entry
#define COMMON_NGRAM_MOD_MAX_CHOICES 4

// ring buffer of recently seen continuations for one hashed n-gram
struct common_ngram_mod_entry {
    uint32_t head = 0;
    uint32_t n_choices = 0;

    llama_token choices[COMMON_NGRAM_MOD_MAX_CHOICES];
};

struct common_ngram_mod {
    common_ngram_mod(uint16_t m);

    void add(const llama_token * tokens);
    llama_token get(const llama_token * tokens, int32_t offs) const;

    uint64_t n_calls = 0;

    uint16_t m; // number of tokens to draft per call

    std::vector<common_ngram_mod_entry> entries;

    // mixed-radix moduli for the n-gram hash; the table holds
    // prod(mods) = 2*8*16*32*64 = 524288 entries
    static constexpr int32_t N_MODS = 17;
    static constexpr int32_t mods[N_MODS] = { 2, 1, 1, 1, 8, 1, 1, 1, 16, 1, 1, 1, 32, 1, 1, 1, 64, };

    static uint64_t idx(const llama_token * tokens);
};

void common_ngram_mod_draft(
        common_ngram_mod & mod,
        const llama_tokens & inp,
        llama_token sampled,
        llama_tokens & draft);

View File

@@ -23,6 +23,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
    COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD,
    COMMON_SPECULATIVE_TYPE_NGRAM_CACHE
};
@@ -33,6 +34,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
{"ngram_map_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD},
{"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
};
@@ -232,6 +234,15 @@ struct common_speculative_state_ngram_map_k4v : public common_speculative_state_
        : common_speculative_state_ngram_map_k(type, std::move(map)) {}
};

struct common_speculative_state_ngram_mod : public common_speculative_state {
    common_ngram_mod mod;

    common_speculative_state_ngram_mod(
            enum common_speculative_type type,
            common_ngram_mod mod)
        : common_speculative_state(type), mod(std::move(mod)) {}
};
struct common_speculative_state_ngram_cache : public common_speculative_state {
uint16_t n_draft;
bool save_dynamic;
@@ -323,6 +334,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
        case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:  return "ngram_simple";
        case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:   return "ngram_map_k";
        case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
        case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD: return "ngram_map_mod";
        case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE:   return "ngram_cache";
        default:                                    return "unknown";
    }
@@ -362,6 +374,7 @@ struct common_speculative * common_speculative_init(
    bool has_ngram_simple  = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
    bool has_ngram_map_k   = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
    bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V);
    bool has_ngram_map_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD);

    // In a more complex implementation we could use the same implementation but with different parameters.
    // This was initially used in PR-18471 but removed to simplify the code.
@@ -376,6 +389,9 @@ struct common_speculative * common_speculative_init(
        // This implementation can guess tokens with high acceptance rate but is more expensive.
        configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params));
    }

    if (has_ngram_map_mod) {
        configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD, params));
    }

    if (has_ngram_cache) {
        configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
    }
@@ -434,8 +450,16 @@ struct common_speculative * common_speculative_init(
}
        case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
            implementations.push_back(std::make_unique<common_speculative_state_ngram_map_k4v>(
                (config.type),
                get_common_ngram_map(config)
            ));
            break;
        }
        case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD: {
            common_ngram_mod mod(config.params.ngram_size_m);
            implementations.push_back(std::make_unique<common_speculative_state_ngram_mod>(
                (config.type),
                std::move(mod)
            ));
            break;
        }
@@ -794,6 +818,15 @@ llama_tokens common_speculative_gen_draft(
GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
}
} break;
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD:
{
auto * state = dynamic_cast<common_speculative_state_ngram_mod *>(impl.get());
if (state) {
common_ngram_mod_draft(state->mod, prompt_tgt, id_last, result);
} else {
GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
}
} break;
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE:
{
auto * state = dynamic_cast<common_speculative_state_ngram_cache *>(impl.get());
@@ -842,6 +875,8 @@ void common_speculative_accept(struct common_speculative * spec, uint16_t n_acce
        impl->drafts_accepted_tokens += n_accepted;
    }

    LOG_WRN("XXXXXXXXXXXXX n_accepted = %d\n", n_accepted);

    if (impl->type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K ||
        impl->type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V) {

View File

@@ -2039,9 +2039,15 @@ private:
        /*.params_spec.n_draft =*/ n_draft_max,
        /*.params_spec.p_min   =*/ slot.task->params.speculative.p_min,
    };

    const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);

    if (draft.size() > 0) {
        std::string tmp = common_detokenize(slot.ctx, draft);
        //LOG_WRN("XXXXXX: draft: '%s'\n", tmp.c_str());
    }

    // add the sampled token to the batch
    slot.i_batch_dft.push_back(batch.n_tokens);
    common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);