sampling : delegate input allocation to the scheduler (#19266)

* sampling : delegate input allocation to the scheduler

* graph : compute backend samplers only if needed
Author:       Georgi Gerganov
Date:         2026-02-03 22:16:16 +02:00
Committed by: GitHub
Parent:       32b17abdb0
Commit:       faa1bc26ee

3 changed files with 33 additions and 73 deletions
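In short: backend samplers no longer reserve their own ggml context and backend buffer for input tensors at backend_init() time. Instead they create the inputs directly in the graph context at apply() time, so the backend scheduler allocates and places them together with the rest of the sampling graph. A condensed sketch of the before/after pattern, reusing the llama_sampler_dist members that appear in the diff below (the two helper names are illustrative only, not functions from the codebase):

```cpp
#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-cpp.h"   // ggml_context_ptr, ggml_backend_buffer_ptr

// Before this commit: the sampler owned a dedicated context + buffer for its inputs.
static void sampler_alloc_inputs_old(llama_sampler_dist * sctx, ggml_backend_buffer_type_t buft) {
    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,   // create tensor metadata only, data lives in a backend buffer
    };
    sctx->inp_ctx.reset(ggml_init(params));

    sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
    ggml_set_input(sctx->inp_uniform);

    // a sampler-owned backend buffer, separate from the buffers managed by the scheduler
    sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
}

// After this commit: the input tensor is created in the graph context while building the
// sampling graph, so the scheduler allocates it like any other graph input.
static void sampler_create_inputs_new(llama_sampler_dist * sctx, ggml_context * ctx) {
    sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
    ggml_set_input(sctx->inp_uniform);   // filled later by the sampler's set_input callback
}
```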


@@ -1027,11 +1027,7 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
         llama_sampler_chain_n(sampler) > 0;
 
     if (sampler && can_offload) {
-        ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
-
-        auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
-        if (host_buft) {
-            buft = host_buft;
-        }
+        auto * buft = ggml_backend_dev_buffer_type(model.dev_output());
 
         sampler->iface->backend_init(sampler, buft);


@@ -2419,6 +2419,9 @@ void llm_graph_context::build_sampling() const {
         return;
     }
 
+    std::array<ggml_tensor *, 2> outs;
+    outs[0] = res->t_logits;
+
     auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
     res->add_input(std::move(inp_sampling));
@@ -2439,14 +2442,14 @@ void llm_graph_context::build_sampling() const {
     // add a dummy row of logits
     // this trick makes the graph static, regardless of which samplers are activated
     // this is important in order to minimize graph reallocations
-    // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550)
     ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
 
     for (const auto & [seq_id, sampler] : samplers) {
         const auto it = seq_to_logit_row.find(seq_id);
 
         // inactive samplers always work on the first row
-        const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
+        const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
+        const int i_out = it != seq_to_logit_row.end() ? 1 : 0;
 
         ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
         ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
@@ -2463,22 +2466,26 @@ void llm_graph_context::build_sampling() const {
         if (data.sampled != nullptr) {
             res->t_sampled[seq_id] = data.sampled;
-            ggml_build_forward_expand(gf, data.sampled);
+            outs[1] = data.sampled;
+            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
         }
 
         if (data.probs != nullptr) {
             res->t_sampled_probs[seq_id] = data.probs;
-            ggml_build_forward_expand(gf, data.probs);
+            outs[1] = data.probs;
+            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
         }
 
         if (data.logits != nullptr) {
             res->t_sampled_logits[seq_id] = data.logits;
-            ggml_build_forward_expand(gf, data.logits);
+            outs[1] = data.logits;
+            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
         }
 
         if (data.candidates != nullptr) {
             res->t_candidates[seq_id] = data.candidates;
-            ggml_build_forward_expand(gf, data.candidates);
+            outs[1] = data.candidates;
+            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
         }
     }
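The "compute backend samplers only if needed" part relies on ggml_build_forward_select() from PR #18550, whose exact semantics are not shown in this diff. Below is a toy restatement of how the hunk above uses it, under the assumption that the call keeps every candidate output in the graph (so the topology stays static) while the index argument marks which output actually has to be computed for the current batch:

```cpp
#include <array>
#include "ggml.h"

// Mirrors the outs/i_out pattern from the hunk above. Assumption: outs[0] (the plain
// logits) is the fallback for samplers with no logit row in this batch, outs[1] is the
// sampler's real result, and i_out selects which of the two gets computed.
static void expand_sampler_output(ggml_cgraph * gf, ggml_tensor * logits, ggml_tensor * sampled, bool active) {
    std::array<ggml_tensor *, 2> outs = { logits, sampled };

    const int i_out = active ? 1 : 0;

    ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}
```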


@@ -1025,11 +1025,7 @@ struct llama_sampler_dist : public llama_sampler_backend {
     std::mt19937 rng;
 
     // backend input
-    struct ggml_tensor * inp_uniform;
-
-    ggml_context_ptr        inp_ctx;
-    ggml_backend_buffer_ptr inp_buf;
+    ggml_tensor * inp_uniform;
 };
 
 static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
@@ -1138,37 +1134,10 @@ static bool llama_sampler_dist_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_dist *) smpl->ctx;
 
-    // allocate inputs
-    {
-        ggml_init_params params = {
-            /*.mem_size   =*/ ggml_tensor_overhead(),
-            /*.mem_buffer =*/ nullptr,
-            /*.no_alloc   =*/ true,
-        };
-
-        sctx->inp_ctx.reset(ggml_init(params));
-
-        // Create the uniform random scalar input tensor. This will be set by
-        // llama_sampler_dist_backend_set_input after this graph is built.
-        sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
-        ggml_set_name (sctx->inp_uniform, "uniform");
-        ggml_set_input(sctx->inp_uniform);
-
-        // Allocate all tensors from our context to the backend
-        sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
-        ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
-    }
-
     const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
-    if (!res) {
-        sctx->inp_ctx.reset(nullptr);
-        sctx->inp_buf.reset(nullptr);
-    }
-
     return res;
 }
@@ -1178,8 +1147,13 @@ static void llama_sampler_dist_backend_apply(
         struct ggml_cgraph * gf,
         struct llama_sampler_data * data) {
     GGML_UNUSED(gf);
 
     auto * sctx = (llama_sampler_dist *) smpl->ctx;
 
+    sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+    ggml_set_name (sctx->inp_uniform, "uniform");
+    ggml_set_input(sctx->inp_uniform);
+
     struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
     ggml_set_name(probs, "dist_probs");
@@ -1226,6 +1200,7 @@ static void llama_sampler_dist_backend_apply(
 static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
     auto * sctx = (llama_sampler_dist *) smpl->ctx;
+    GGML_ASSERT(sctx->inp_uniform != nullptr);
 
     // We sample in double precision and cast to float to match rnd numbers of
@@ -1262,8 +1237,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
             /* .seed_cur    = */ seed_cur,
             /* .rng         = */ std::mt19937(seed_cur),
             /* .inp_uniform = */ nullptr,
-            /* .inp_ctx     = */ nullptr,
-            /* .inp_buf     = */ nullptr,
         }
     );
 }
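For the runtime side, here is a minimal sketch of what the set_input step now amounts to once the scheduler has allocated the graph: the sampler only writes a value into its input tensor and no longer owns the buffer behind it. The body below is an assumption; apart from the GGML_ASSERT and the double-precision comment taken from the hunk above, the actual llama_sampler_dist_backend_set_input implementation is outside this diff.

```cpp
#include <random>
#include "ggml.h"
#include "ggml-backend.h"

// Sketch only, not the actual implementation.
static void dist_set_input_sketch(llama_sampler_dist * sctx) {
    GGML_ASSERT(sctx->inp_uniform != nullptr);  // created while building the sampling graph

    // sample in double precision and cast to float, per the comment in the diff
    std::uniform_real_distribution<double> dist(0.0, 1.0);
    const float u = (float) dist(sctx->rng);

    // upload into whatever backend buffer the scheduler placed this input tensor in
    ggml_backend_tensor_set(sctx->inp_uniform, &u, 0, sizeof(u));
}
```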
@@ -3461,9 +3434,6 @@ struct llama_sampler_logit_bias : public llama_sampler_backend {
     struct ggml_tensor * inp_logit_bias;
     struct ggml_tensor * inp_logit_idxs;
-
-    ggml_context_ptr        inp_ctx;
-    ggml_backend_buffer_ptr inp_buf;
 };
 
 static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
@@ -3526,6 +3496,16 @@ static void llama_sampler_logit_bias_backend_apply(
         return;
     }
 
+    const size_t n = sctx->logit_bias.size();
+
+    sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n);
+    ggml_set_name(sctx->inp_logit_bias, "logit_bias");
+    ggml_set_input(sctx->inp_logit_bias);
+
+    sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n);
+    ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
+    ggml_set_input(sctx->inp_logit_idxs);
+
     ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
     cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
@@ -3562,6 +3542,8 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm
 static bool llama_sampler_logit_bias_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
+    GGML_UNUSED(buft);
+
     auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
 
     sctx->init(true);
@@ -3570,29 +3552,6 @@ static bool llama_sampler_logit_bias_backend_init(
         return true;
     }
 
-    ggml_init_params params = {
-        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
-        /*.mem_buffer =*/ nullptr,
-        /*.no_alloc   =*/ true,
-    };
-
-    sctx->inp_ctx.reset(ggml_init(params));
-
-    const size_t n = sctx->logit_bias.size();
-
-    sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n);
-    ggml_set_name(sctx->inp_logit_bias, "logit_bias");
-    ggml_set_input(sctx->inp_logit_bias);
-
-    sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n);
-    ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
-    ggml_set_input(sctx->inp_logit_idxs);
-
-    // Allocate all tensors from our context to the backend
-    sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
-    ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
-
     return true;
 }
@@ -3628,8 +3587,6 @@ struct llama_sampler * llama_sampler_init_logit_bias(
             /* .to_search      = */ {},
             /* .inp_logit_bias = */ nullptr,
             /* .inp_logit_idxs = */ nullptr,
-            /* .inp_ctx        = */ nullptr,
-            /* .inp_buf        = */ nullptr,
         }
     );
 }