Compare commits


9 Commits
b6860 ... b6869

Author SHA1 Message Date
YaelGitAccount
851553ea6b cuda: add SET operation support (#16804)
* feat(cuda): add GGML_OP_SET support

Implement CUDA kernel for SET operation with f32 support.

All tests passing (14598/14598).

* cuda(set): add I32 support; keep F32

* refactor(cuda): use ggml_cuda_cpy to unify SET operator logic and remove code duplication

* Update ggml/src/ggml-cuda/ggml-cuda.cu

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update ggml/src/ggml-cuda/set.cu

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-10-28 20:10:28 +01:00
Georgi Gerganov
85a7d8677b memory : remove KV cache size padding (#16812)
* memory : remove KV cache size padding

* cont : restore padding for n_kv tensor shape

* server : use slot context size instead of training context size

* server : simplify context limit logic
2025-10-28 20:19:44 +02:00
Georgi Gerganov
a8ca18b4b8 llama-bench : clarify benchmarked parts of the computation (#16823) 2025-10-28 19:41:43 +02:00
l3utterfly
8284efc35c initialise buffer.device in ggml_hexagon_session (#16816) 2025-10-28 08:16:20 -07:00
Sam Malayek
1c1409e131 embedding: add raw option for --embd-output-format (#16541)
* Add --embd-output-format raw for plain numeric embedding output

This new option outputs embeddings as raw space-separated floats, without JSON or 'embedding N:' prefixes. Useful for downstream vector pipelines and scripting.

* Move raw output handling into format handling section

* Move raw output handling into else-if block with other format handlers

* Use LOG instead of printf for raw embedding output

* docs: document 'raw' embedding output format in arg.cpp and README
2025-10-28 12:51:41 +02:00
Johannes Gäßler
7a0e900e36 llama: consistent ctx <-> buf order for KV cache (#16746) 2025-10-28 11:23:54 +01:00
Aldehir Rojas
280d97be96 grammar : support array references in json schema (#16792)
* grammar : support array references in json schema

* Update json-schema-to-grammar.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* grammar : improve regex when naming ref derived rules

* grammar : replace non-conformant definitions array with anyOf test case

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-10-28 09:37:52 +01:00
Chenguang Li
3479efd112 CANN: Improve device ID handling and aclnnArange checks (#16752)
* cann: improve device ID handling and aclnnArange checks

- Stop relying on CANN's internal device ID retrieval; use a global variable instead.
- Enforce stricter dimension validation in aclnnArange for better compatibility across CANN versions.

* cann: use thread local var
2025-10-28 10:54:53 +08:00
Aman Gupta
463bbf20bf CUDA: add unused vars to mmvf and mmvq (#16807) 2025-10-28 10:31:21 +08:00
24 changed files with 262 additions and 117 deletions


@@ -3248,7 +3248,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
add_opt(common_arg(
{"--embd-output-format"}, "FORMAT",
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix",
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
[](common_params & params, const std::string & value) {
params.embd_out = value;
}


@@ -601,7 +601,10 @@ private:
}
std::string _resolve_ref(const std::string & ref) {
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
auto it = ref.find('#');
std::string ref_fragment = it != std::string::npos ? ref.substr(it + 1) : ref;
static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)");
std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-");
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
_refs_being_resolved.insert(ref);
json resolved = _refs[ref];
@@ -774,11 +777,24 @@ public:
std::vector<std::string> tokens = string_split(pointer, "/");
for (size_t i = 1; i < tokens.size(); ++i) {
std::string sel = tokens[i];
if (target.is_null() || !target.contains(sel)) {
if (target.is_object() && target.contains(sel)) {
target = target[sel];
} else if (target.is_array()) {
size_t sel_index;
try {
sel_index = std::stoul(sel);
} catch (const std::invalid_argument & e) {
sel_index = target.size();
}
if (sel_index >= target.size()) {
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
return;
}
target = target[sel_index];
} else {
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
return;
}
target = target[sel];
}
_refs[ref] = target;
}
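
A quick illustration of the new rule naming above (a standalone sketch, not part of the patch): the fragment after '#' is slugified with the same character class, so a reference like #/definitions/foo becomes the grammar rule ref-definitions-foo, which is exactly the name that appears in the updated test expectations further down.

#include <cassert>
#include <regex>
#include <string>

int main() {
    const std::string ref = "#/definitions/foo";
    const auto hash = ref.find('#');
    const std::string fragment = hash != std::string::npos ? ref.substr(hash + 1) : ref;
    // same character class as the patch: collapse runs of anything
    // outside [a-zA-Z0-9-] into a single dash
    static const std::regex nonalphanumeric(R"([^a-zA-Z0-9-]+)");
    const std::string rule = "ref" + std::regex_replace(fragment, nonalphanumeric, "-");
    assert(rule == "ref-definitions-foo"); // matches the rule names in the updated grammar tests
    return 0;
}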


@@ -38,6 +38,7 @@ The above command will output space-separated float values.
| | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$
| 'json' | openai style |
| 'json+' | add cosine similarity matrix |
| 'raw' | plain text output |
### --embd-separator $"string"$
| $"string"$ | |


@@ -70,6 +70,29 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}
}
// plain, pipe-friendly output: one embedding per line
static void print_raw_embeddings(const float * emb,
int n_embd_count,
int n_embd,
const llama_model * model,
enum llama_pooling_type pooling_type,
int embd_normalize) {
const uint32_t n_cls_out = llama_model_n_cls_out(model);
const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK);
const int cols = is_rank ? std::min<int>(n_embd, (int) n_cls_out) : n_embd;
for (int j = 0; j < n_embd_count; ++j) {
for (int i = 0; i < cols; ++i) {
if (embd_normalize == 0) {
LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
} else {
LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
}
}
LOG("\n");
}
}
int main(int argc, char ** argv) {
common_params params;
@@ -372,6 +395,8 @@ int main(int argc, char ** argv) {
}
if (notArray) LOG("\n}\n");
} else if (params.embd_out == "raw") {
print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize);
}
LOG("\n");
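
A hypothetical downstream consumer (not part of the patch) shows why this format is pipe-friendly: each line is one embedding, so the output of e.g. llama-embedding -m model.gguf -p "some text" --embd-output-format raw can be parsed with nothing more than a float extractor.

#include <iostream>
#include <sstream>
#include <vector>

int main() {
    // read whitespace-delimited floats, one embedding per line, from stdin
    std::vector<std::vector<float>> embeddings;
    std::string line;
    while (std::getline(std::cin, line)) {
        std::istringstream iss(line);
        std::vector<float> v;
        for (float x; iss >> x; ) {
            v.push_back(x);
        }
        if (!v.empty()) {
            embeddings.push_back(std::move(v));
        }
    }
    std::cout << "read " << embeddings.size() << " embeddings\n";
    return 0;
}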


@@ -371,8 +371,17 @@ class SchemaConverter:
raise ValueError(f'Unsupported ref {ref}')
for sel in ref.split('#')[-1].split('/')[1:]:
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel]
assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}'
if isinstance(target, list):
try:
sel_index = int(sel)
except ValueError:
raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}')
assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel_index]
else:
assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel]
self._refs[ref] = target
else:
@@ -547,7 +556,8 @@ class SchemaConverter:
def _resolve_ref(self, ref):
ref_name = ref.split('/')[-1]
ref_fragment = ref.split('#')[-1]
ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment)
if ref_name not in self._rules and ref not in self._refs_being_resolved:
self._refs_being_resolved.add(ref)
resolved = self._refs[ref]


@@ -2234,7 +2234,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
ACL_MEM_MALLOC_HUGE_FIRST));
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
theta_scale_ne, theta_scale_nb, 1);
float start = 0;
float step = 1;
@@ -2251,7 +2251,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne,
theta_scale_nb, GGML_MAX_DIMS);
theta_scale_nb, 1);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);


@@ -67,19 +67,30 @@
GGML_ABORT("CANN error");
}
// Thread-local variable to record the current device of this thread.
thread_local int g_current_cann_device = -1;
/**
* @brief Sets the device to be used by CANN.
* @brief Set the CANN device to be used.
*
* @param device The device ID to set.
* @param device The target device ID to set.
*/
void ggml_cann_set_device(const int32_t device) {
int current_device = -1;
aclrtGetDevice(&current_device);
// int current_device = -1;
// Note: In some CANN versions, if no device has been set yet,
// aclrtGetDevice(&current_device) may return 0 by default.
// aclrtGetDevice(&current_device);
if (device == current_device) {
// If the current device is already the target one, no need to switch.
if (device == g_current_cann_device) {
return;
}
// Switch to the new device.
ACL_CHECK(aclrtSetDevice(device));
// Update the global device record.
g_current_cann_device = device;
}

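Why thread_local rather than a plain global: runtimes like CANN track the current device per thread, so the cache that guards aclrtSetDevice must be per-thread as well. A minimal standalone sketch of the pattern (set_device_cached is a hypothetical name, with the ACL call stubbed out):

#include <thread>

// per-thread cache of the last device id handed to the runtime; a plain
// global would let one thread skip a switch that only another thread made
thread_local int g_current_device = -1;

static void set_device_cached(int device) {
    if (device == g_current_device) {
        return; // this thread already targets `device`, skip the runtime call
    }
    // real backend: ACL_CHECK(aclrtSetDevice(device));
    g_current_device = device;
}

int main() {
    std::thread t0([] { set_device_cached(0); set_device_cached(0); }); // second call is a no-op
    std::thread t1([] { set_device_cached(1); });                       // unaffected by t0's cache
    t0.join();
    t1.join();
    return 0;
}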

@@ -50,6 +50,7 @@
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml-cuda/set.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
#include "ggml.h"
@@ -2416,6 +2417,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SET_ROWS:
ggml_cuda_op_set_rows(ctx, dst);
break;
case GGML_OP_SET:
ggml_cuda_op_set(ctx, dst);
break;
case GGML_OP_DUP:
ggml_cuda_dup(ctx, dst);
break;
@@ -3842,6 +3846,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
op->src[0]->type == GGML_TYPE_F32 &&
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
} break;
case GGML_OP_SET:
{
const ggml_type t = op->type;
return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
t == op->src[0]->type &&
t == op->src[1]->type;
} break;
case GGML_OP_CPY:
{
ggml_type src0_type = op->src[0]->type;


@@ -343,6 +343,10 @@ static __global__ void mul_mat_vec_f(
}
dst[tid*stride_col_dst + row] = value;
if constexpr (!has_fusion) {
GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate);
}
}
template<typename T, typename type_acc, int ncols_dst, int block_size>


@@ -310,6 +310,10 @@ static __global__ void mul_mat_vec_q(
dst[j*stride_col_dst + threadIdx.x] = result;
}
}
if constexpr (!has_fusion) {
GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate);
}
}
static std::pair<dim3, dim3> calc_launch_params(
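
Context for the mmvf and mmvq hunks above: when has_fusion is false, the fusion-related locals are referenced only in branches that the instantiation discards, and some compilers then emit unused-variable warnings. GGML_UNUSED_VARS void-casts them away. A self-contained sketch of the same pattern in generic C++ (the helper below is a stand-in, not the ggml macro):

// stand-in for GGML_UNUSED_VARS: void-cast a pack of variables so that
// unused-variable warnings stay quiet in the non-fused instantiation
template <typename... Args>
constexpr void unused_vars(const Args &... args) {
    ((void) args, ...);
}

template <bool has_fusion>
float scale(float x) {
    float gate = 0.0f;
    if constexpr (has_fusion) {
        gate = x * 0.5f;   // the fused path actually consumes it
        return x + gate;
    }
    if constexpr (!has_fusion) {
        unused_vars(gate); // non-fused instantiation: silence the warning
    }
    return x;
}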

ggml/src/ggml-cuda/set.cu (new file, 39 lines)

@@ -0,0 +1,39 @@
#include "set.cuh"
#include "cpy.cuh"
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
GGML_ASSERT(src1->type == src0->type);
GGML_ASSERT(dst ->type == src0->type);
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
const size_t nb1 = ((int32_t *) dst->op_params)[0];
const size_t nb2 = ((int32_t *) dst->op_params)[1];
const size_t nb3 = ((int32_t *) dst->op_params)[2];
const size_t offset = ((int32_t *) dst->op_params)[3];
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
if (!inplace) {
ggml_cuda_cpy(ctx, src0, dst);
}
ggml_tensor dst_view = *dst;
dst_view.data = (void *)((char *)dst->data + offset);
dst_view.ne[0] = src1->ne[0];
dst_view.ne[1] = src1->ne[1];
dst_view.ne[2] = src1->ne[2];
dst_view.ne[3] = src1->ne[3];
dst_view.nb[0] = ggml_element_size(dst);
dst_view.nb[1] = nb1;
dst_view.nb[2] = nb2;
dst_view.nb[3] = nb3;
ggml_cuda_cpy(ctx, src1, &dst_view);
}
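
For context, the graph-level operation this kernel serves via two copies: ggml_set() from ggml.h pastes src1 into a strided view of src0 described by (nb1, nb2, nb3, offset). A hedged usage sketch, assuming an already initialized ggml_context:

#include "ggml.h"

// paste a 2x2 tensor b into a 4x4 tensor a, starting at row 1, column 1
ggml_tensor * set_example(ggml_context * ctx) {
    ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 4); // destination
    ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 2); // patch
    const size_t offset = a->nb[1] + a->nb[0]; // one row plus one element, in bytes
    // on CUDA this now lowers to ggml_cuda_cpy(a -> dst) followed by
    // ggml_cuda_cpy(b -> strided view of dst at `offset`)
    return ggml_set(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset);
}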


@@ -0,0 +1,7 @@
#pragma once
#include "common.cuh"
#define CUDA_SET_BLOCK_SIZE 256
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);


@@ -211,7 +211,7 @@ static inline void hex_format_op_names(char * str, const struct ggml_tensor * t)
// ** backend sessions
struct ggml_hexagon_session {
ggml_hexagon_session(int dev_id) noexcept(false);
ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false);
~ggml_hexagon_session() noexcept(true);
void allocate(int dev_id) noexcept(false);
@@ -1631,10 +1631,13 @@ void ggml_hexagon_session::release() noexcept(true) {
}
}
ggml_hexagon_session::ggml_hexagon_session(int dev_id) noexcept(false) {
ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) {
buffer_type.context = nullptr;
repack_buffer_type.context = nullptr;
buffer_type.device = dev;
repack_buffer_type.device = dev;
try {
allocate(dev_id);
@@ -3628,7 +3631,7 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
devices[i].iface = ggml_backend_hexagon_device_i;
devices[i].reg = reg;
try {
devices[i].context = new ggml_hexagon_session(i);
devices[i].context = new ggml_hexagon_session(i, &devices[i]);
} catch (std::exception const &exc) {
GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i);
devices[i].context = nullptr;


@@ -8,6 +8,7 @@
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstring>
#include <limits>
#include <map>
#include <stdexcept>
@@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache(
const uint32_t n_layer_kv = hparams.n_layer_kv();
// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
struct ggml_backend_buft_comparator {
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
}
};
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
@@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache(
return nullptr;
}
ctx_map[buft] = ctx;
ctxs.emplace_back(ctx);
ctx_map.emplace(buft, ctx);
return ctx;
}
return it->second;
return it->second.get();
};
GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
@@ -167,11 +174,8 @@ llama_kv_cache::llama_kv_cache(
}
// allocate tensors and initialize the buffers to avoid NaNs in the padding
for (auto it : ctx_map) {
auto * buft = it.first;
auto * ctx = it.second;
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
for (auto & [buft, ctx] : ctx_map) {
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
if (!buf) {
throw std::runtime_error("failed to allocate buffer for kv cache");
}
@@ -179,7 +183,7 @@ llama_kv_cache::llama_kv_cache(
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
ggml_backend_buffer_clear(buf, 0);
bufs.emplace_back(buf);
ctxs_bufs.emplace_back(std::move(ctx), buf);
}
{
@@ -203,7 +207,7 @@ void llama_kv_cache::clear(bool data) {
}
if (data) {
for (auto & buf : bufs) {
for (auto & [_, buf] : ctxs_bufs) {
ggml_backend_buffer_clear(buf.get(), 0);
}
}
@@ -472,8 +476,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
for (const auto & [_, buf] : ctxs_bufs) {
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
}
return ret;
}
@@ -957,10 +961,14 @@ bool llama_kv_cache::get_has_shift() const {
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
uint32_t result = 0;
// pad the n_kv value so that the graph remains constant across batches and can be reused
// note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
const uint32_t n_pad_cur = std::max(n_pad, 256u);
for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
const auto & cells = v_cells[sinfo.strm[s]];
result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
}
return result;
@@ -1298,7 +1306,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
size_t llama_kv_cache::total_size() const {
size_t size = 0;
for (const auto & buf : bufs) {
for (const auto & [_, buf] : ctxs_bufs) {
size += ggml_backend_buffer_get_size(buf.get());
}
@@ -2010,8 +2018,3 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_pos_bucket(dst, ubatch);
}
uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
// the FA kernels require padding to avoid extra runtime boundary checks
return cparams.flash_attn ? 256u : 32u;
}
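
A worked example of the new padding rule in get_n_kv(), which replaces the removed get_padding() (256 with flash attention, 32 otherwise) with a uniform floor of 256. The numbers below are hypothetical: with 300 used cells, a cache of 4096 and n_pad = 32, the floor kicks in and the graph sees 512 KV cells.

#include <algorithm>
#include <cstdint>
#include <cstdio>

#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n)) // round x up to a multiple of n

int main() {
    const uint32_t used_max_p1 = 300, size = 4096, n_pad = 32;
    const uint32_t n_pad_cur = std::max(n_pad, 256u);
    const uint32_t n_kv = std::min(size, std::max(n_pad_cur, GGML_PAD(used_max_p1, n_pad_cur)));
    printf("n_kv = %u\n", n_kv); // 512: 300 rounded up to the next multiple of 256
    return 0;
}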


@@ -19,8 +19,6 @@ struct llama_context;
class llama_kv_cache : public llama_memory_i {
public:
static uint32_t get_padding(const llama_cparams & cparams);
struct stream_copy_info {
bool empty() const {
assert(ssrc.size() == sdst.size());
@@ -217,8 +215,8 @@ private:
// this is the SWA type of the cache - not to be confused with the model SWA type
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
// ggml contexts for the KV cache along with the allocated backend buffers:
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method


@@ -7,6 +7,7 @@
#include <algorithm>
#include <cassert>
#include <cstring>
#include <limits>
#include <map>
#include <stdexcept>
@@ -32,8 +33,15 @@ llama_memory_recurrent::llama_memory_recurrent(
cells.clear();
cells.resize(mem_size);
// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
struct ggml_backend_buft_comparator {
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
}
};
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
@@ -48,13 +56,12 @@ llama_memory_recurrent::llama_memory_recurrent(
return nullptr;
}
ctx_map[buft] = ctx;
ctxs.emplace_back(ctx);
ctx_map.emplace(buft, ctx);
return ctx;
}
return it->second;
return it->second.get();
};
r_l.resize(n_layer);
@@ -93,17 +100,14 @@ llama_memory_recurrent::llama_memory_recurrent(
}
// allocate tensors and initialize the buffers to avoid NaNs in the padding
for (auto it : ctx_map) {
auto * buft = it.first;
auto * ctx = it.second;
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
for (auto & [buft, ctx] : ctx_map) {
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
if (!buf) {
throw std::runtime_error("failed to allocate buffer for rs cache");
}
ggml_backend_buffer_clear(buf, 0);
LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
bufs.emplace_back(buf);
ctxs_bufs.emplace_back(std::move(ctx), buf);
}
{
@@ -129,7 +133,7 @@ void llama_memory_recurrent::clear(bool data) {
used = 0;
if (data) {
for (auto & buf : bufs) {
for (auto & [_, buf] : ctxs_bufs) {
ggml_backend_buffer_clear(buf.get(), 0);
}
}
@@ -364,8 +368,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
for (const auto & [_, buf] : ctxs_bufs) {
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
}
return ret;
}
@@ -662,7 +666,7 @@ bool llama_memory_recurrent::get_can_shift() const {
size_t llama_memory_recurrent::total_size() const {
size_t size = 0;
for (const auto & buf : bufs) {
for (const auto & [_, buf] : ctxs_bufs) {
size += ggml_backend_buffer_get_size(buf.get());
}


@@ -109,8 +109,8 @@ private:
const uint32_t n_seq_max = 1;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
// ggml contexts for the KV cache along with the allocated backend buffers:
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
size_t total_size() const;


@@ -2231,7 +2231,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
struct ggml_backend_buft_comparator {
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
}
};
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
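
The one-line comparator fix above (the same comparator is also introduced in llama-kv-cache.cpp and llama-memory-recurrent.cpp earlier in this diff) deserves a gloss: ggml_backend_buft_name() returns a const char *, and operator< on two unrelated pointers compares addresses, not characters, so the map's ordering could vary between runs. A minimal illustration:

#include <cstring>

// orders by pointer address: unstable across runs, and formally
// unspecified for pointers into unrelated strings
bool by_address(const char * a, const char * b) { return a < b; }

// orders by the characters themselves: stable and well-defined
bool by_name(const char * a, const char * b) { return strcmp(a, b) < 0; }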
@@ -19641,7 +19641,7 @@ struct llm_build_apertus : public llm_graph_context {
}
};
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const {
llama_memory_i * res;
switch (arch) {
@@ -19692,17 +19692,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
};
}
const auto padding = llama_kv_cache::get_padding(cparams);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
res = new llama_memory_hybrid(
/* model */ *this,
/* attn_type_k */ params.type_k,
/* attn_type_v */ params.type_v,
/* attn_v_trans */ !cparams.flash_attn,
/* attn_kv_size */ cparams.n_ctx,
/* attn_n_pad */ padding,
/* attn_n_pad */ 1,
/* attn_n_swa */ hparams.n_swa,
/* attn_swa_type */ hparams.swa_type,
/* recurrent_type_k */ GGML_TYPE_F32,
@@ -19714,23 +19710,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* filter_attn */ std::move(filter_attn),
/* filter_recr */ std::move(filter_recr));
} else {
const auto padding = llama_kv_cache::get_padding(cparams);
uint32_t n_ctx_per_stream = cparams.n_ctx;
if (!cparams.kv_unified) {
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
} else {
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
cparams.n_ctx = n_ctx_per_stream;
}
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
llama_memory_i::layer_reuse_cb reuse = nullptr;
if (arch == LLM_ARCH_GEMMA3N) {
@@ -19757,7 +19742,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
n_ctx_per_stream,
cparams.n_seq_max,
cparams.n_ubatch,
padding,
1,
nullptr,
reuse);
} else {
@@ -19772,7 +19757,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_seq_max,
padding,
1,
hparams.n_swa,
hparams.swa_type,
nullptr,


@@ -500,9 +500,8 @@ struct llama_model {
ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
// note: can mutate `cparams`
// TODO: move this to new llm_arch_model_i interface
llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const;
// TODO: move this to new llm_arch_model_i interface
ggml_cgraph * build_graph(const llm_graph_params & params) const;


@@ -1124,9 +1124,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
foo ::= "{" space foo-a-kv "}" space
foo-a-kv ::= "\"a\"" space ":" space string
root ::= foo
ref-definitions-foo ::= "{" space ref-definitions-foo-a-kv "}" space
ref-definitions-foo-a-kv ::= "\"a\"" space ":" space string
root ::= ref-definitions-foo
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
@@ -1151,20 +1151,58 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"type": "object"
})""",
R"""(
alternative-0 ::= foo
alternative-1 ::= bar
bar ::= "{" space (bar-b-kv )? "}" space
bar-b-kv ::= "\"b\"" space ":" space number
alternative-0 ::= ref-definitions-foo
alternative-1 ::= ref-definitions-bar
decimal-part ::= [0-9]{1,16}
foo ::= "{" space (foo-a-kv )? "}" space
foo-a-kv ::= "\"a\"" space ":" space number
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
ref-definitions-bar ::= "{" space (ref-definitions-bar-b-kv )? "}" space
ref-definitions-bar-b-kv ::= "\"b\"" space ":" space number
ref-definitions-foo ::= "{" space (ref-definitions-foo-a-kv )? "}" space
ref-definitions-foo-a-kv ::= "\"a\"" space ":" space number
root ::= alternative-0 | alternative-1
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
test({
SUCCESS,
"anyOf $ref",
R"""({
"properties": {
"a": {
"anyOf": [
{"type": "string"},
{"type": "number"}
]
},
"b": {
"anyOf": [
{"$ref": "#/properties/a/anyOf/0"},
{"type": "boolean"}
]
}
},
"type": "object"
})""",
R"""(
a ::= string | number
a-kv ::= "\"a\"" space ":" space a
a-rest ::= ( "," space b-kv )?
b ::= b-0 | boolean
b-0 ::= string
b-kv ::= "\"b\"" space ":" space b
boolean ::= ("true" | "false") space
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space (a-kv a-rest | b-kv )? "}" space
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
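
The "anyOf $ref" test above can be reproduced with the JSON pointer support built into nlohmann::json, the library the C++ converter already uses: numeric selectors index into arrays, which is what all three converters now implement. A sketch, not repo code:

#include <nlohmann/json.hpp>
#include <cassert>

int main() {
    using json = nlohmann::json;
    const json schema = json::parse(R"({
        "properties": {
            "a": { "anyOf": [ { "type": "string" }, { "type": "number" } ] }
        }
    })");
    // "0" selects the first element of the anyOf array
    const json target = schema.at(json::json_pointer("/properties/a/anyOf/0"));
    assert(target == json({ { "type", "string" } })); // resolves to the string alternative
    return 0;
}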
test({
SUCCESS,
"mix of allOf, anyOf and $ref (similar to https://json.schemastore.org/tsconfig.json)",


@@ -82,6 +82,9 @@ Using the `-d <n>` option, each test can be run at a specified context depth, pr
For a description of the other options, see the [main example](../main/README.md).
> [!NOTE]
> The measurements with `llama-bench` do not include the times for tokenization and for sampling.
## Examples
### Text generation with different models
@@ -131,7 +134,7 @@ $ ./llama-bench -n 0 -n 16 -p 64 -t 1,2,4,8,16,32
| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 16 | pp 64 | 33.52 ± 0.03 |
| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 16 | tg 16 | 15.32 ± 0.05 |
| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 32 | pp 64 | 59.00 ± 1.11 |
| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 32 | tg 16 | 16.41 ± 0.79 ||
| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 32 | tg 16 | 16.41 ± 0.79 |
### Different numbers of layers offloaded to the GPU


@@ -345,10 +345,14 @@ export class SchemaConverter {
const selectors = ref.split('#')[1].split('/').slice(1);
for (const sel of selectors) {
if (!target || !(sel in target)) {
const selIndex = parseInt(sel, 10);
if (target && sel in target) {
target = target[sel];
} else if (target && selIndex in target) {
target = target[selIndex];
} else {
throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`);
}
target = target[sel];
}
this._refs[ref] = target;
@@ -594,7 +598,8 @@ export class SchemaConverter {
}
_resolveRef(ref) {
let refName = ref.split('/').pop();
let refFragment = ref.split('#').pop();
let refName = 'ref' + refFragment.replace(/[^a-zA-Z0-9-]+/g, '-');
if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) {
this._refsBeingResolved.add(ref);
const resolved = this._refs[ref];


@@ -2866,10 +2866,12 @@ struct server_context {
// if context shifting is disabled, make sure that we don't run out of context
if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
slot.truncated = true;
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx);
SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
}
// check the limits
@@ -2929,16 +2931,6 @@ struct server_context {
}
}
// if context shift is disabled, we stop when it reaches the context limit
if (slot.n_past >= slot.n_ctx) {
slot.truncated = true;
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
}
if (llama_vocab_is_eog(vocab, result.tok)) {
slot.stop = STOP_TYPE_EOS;
slot.has_next_token = false;
@@ -2946,19 +2938,6 @@ struct server_context {
SLT_DBG(slot, "%s", "stopped by EOS\n");
}
const auto n_ctx_train = llama_model_n_ctx_train(model);
if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) {
slot.truncated = true;
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false; // stop prediction
SLT_WRN(slot,
"n_predict (%d) is set for infinite generation. "
"Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
slot.task->params.n_predict, n_ctx_train);
}
SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
return slot.has_next_token; // continue
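
The consolidated stop condition is easy to sanity-check by hand: with context shift disabled, generation halts once the next token would no longer fit in the slot context. A toy version of the check (hypothetical, simplified signature):

// 8 prompt tokens + 248 generated tokens fill n_ctx = 256 exactly,
// which is the figure the updated test below asserts
bool should_stop(int n_past, int n_ctx) {
    return n_past + 1 >= n_ctx; // the next position would not fit
}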


@@ -45,7 +45,7 @@ def test_ctx_shift_enabled():
@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
(64, 64, False),
(-1, 120, True),
(-1, 248, True), # 8 tokens prompt + 248 tokens generated = 256 tokens total
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
global server