Compare commits


10 Commits

Author SHA1 Message Date
Georgi Gerganov
d86e23101e server : minor log updates
ggml-ci
2025-02-08 16:23:37 +02:00
Johannes Gäßler
d80be897ac CUDA: fix min. version for movmatrix (#11751) 2025-02-08 10:46:07 +01:00
Nikolaos Pothitos
3ab410f55f readme : update front-end framework (#11753)
After the migration to React with #11688
2025-02-08 10:43:04 +01:00
Xuan-Son Nguyen
0cf867160c server : (webui) fix numeric settings being saved as string (#11739)
* server : (webui) fix numeric settings being saved as string

* add some more comments
2025-02-08 10:42:34 +01:00
Eric Curtin
d2fe216fb2 Make logging more verbose (#11714)
Debugged an issue with a user who was on a read-only filesystem.

Signed-off-by: Eric Curtin <ecurtin@redhat.com>
2025-02-07 14:42:46 +00:00
Georgi Gerganov
ed926d8833 llama : fix defrag logic (#11707)
* llama : fix defrag logic

ggml-ci

* cont : better logic

ggml-ci

* cont : clamp fragmentation to 0.0

ggml-ci
2025-02-07 16:05:34 +02:00
Christian Fillion
2d219b389e vocab : ignore invalid UTF-8 input in the BPE tokenizer (#11729)
Silently insert U+FFFD (the Unicode replacement character) instead, until the
next valid codepoint is found.

This fixes `llama_tokenize` throwing an exception across the C API boundary
or libllama's module boundary (the caller's runtime might be incompatible!).

Returning a proper error code might be desirable; however, the signature
of `llama_tokenize` doesn't allow it, as all return values already have
an existing meaning.
2025-02-07 15:55:47 +02:00
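
A minimal sketch of the resulting behavior (the function name comes from the unicode.cpp hunk below; the byte values are illustrative):

```cpp
#include <cstdint>
#include <string>
#include <vector>

// declared in unicode.h; the diff below patches its implementation
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);

int main() {
    // 0xFF can never appear in well-formed UTF-8
    const std::vector<uint32_t> cpts = unicode_cpts_from_utf8("ab\xFFxy");
    // post-patch: { 'a', 'b', 0xFFFD, 'x', 'y' } -- no exception escapes
}
```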
magicse
333820d749 llama : fix progress dots (#11730)
* Update llama.cpp

Display progress dots in the terminal. Without this change, the dots were
not shown while loading a model from file.

* Update llama.cpp

Removed trailing spaces.
2025-02-07 15:48:47 +02:00
Jeff Bolz
c026ba3c23 vulkan: print shared memory size (#11719) 2025-02-07 11:26:03 +01:00
Christian Fillion
7ee953a64a llama : add llama_sampler_init for safe usage of llama_sampler_free (#11727)
The C API in llama.h claims users can implement `llama_sampler_i` to
create custom `llama_sampler`. The sampler chain takes ownership and
calls `llama_sampler_free` on them. However, `llama_sampler_free` is
hard-coded to use `delete`. This is undefined behavior if the object
wasn't also allocated via `new` from libllama's C++ runtime. Callers
in C and C-compatible languages do not use C++'s `new` operator. C++
callers may not be sharing the same heap as libllama.
2025-02-07 11:33:27 +02:00
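
A minimal usage sketch of the new API from the caller's side (the members of `llama_sampler_i` are elided here; see `llama.h` for the full interface, and the llama.h/llama-sampling.cpp hunks below for the definition of `llama_sampler_init`):

```cpp
#include "llama.h"

// hypothetical custom sampler; fill in the llama_sampler_i callbacks
// (.name, .accept, .apply, ...) as declared in llama.h
static struct llama_sampler_i my_sampler_i = { /* ... */ };

struct llama_sampler * my_sampler_make(void) {
    // allocate through libllama so that llama_sampler_free's internal
    // `delete` matches the allocation -- never `new`/malloc this yourself
    return llama_sampler_init(&my_sampler_i, /* ctx */ NULL);
}
```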
13 changed files with 162 additions and 125 deletions


@@ -254,10 +254,10 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
         };
     }
 
-    return new llama_sampler{
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_llg_i,
-        /* .ctx   = */ ctx,
-    };
+        /* .ctx   = */ ctx
+    );
 }
 
 #else


@@ -346,7 +346,7 @@ class HttpClient {
         if (!output_file.empty()) {
             output_file_partial = output_file + ".partial";
             if (!out.open(output_file_partial, "ab")) {
-                printe("Failed to open file\n");
+                printe("Failed to open file for writing\n");
                 return 1;
             }


@@ -220,7 +220,7 @@ services:
 
 The project includes a web-based user interface that enables interaction with the model through the `/chat/completions` endpoint.
 
 The web UI is developed using:
-- `vue` framework for frontend development
+- `react` framework for frontend development
 - `tailwindcss` and `daisyui` for styling
 - `vite` for build tooling

Binary file not shown.


@@ -334,24 +334,24 @@ struct server_task {
         if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
                 auto schema = json_value(data, "json_schema", json::object());
-                LOG_DBG("JSON schema: %s\n", schema.dump(2).c_str());
+                SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
                 params.sampling.grammar = json_schema_to_grammar(schema);
-                LOG_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
+                SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
             } catch (const std::exception & e) {
                 throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
             }
         } else {
             params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
-            LOG_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
+            SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
             params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
-            LOG_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
+            SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
         }
 
         {
             auto it = data.find("chat_format");
             if (it != data.end()) {
                 params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
-                LOG_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
+                SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
             } else {
                 params.oaicompat_chat_format = defaults.oaicompat_chat_format;
             }
@@ -367,12 +367,12 @@ struct server_task {
                     auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
                     if (ids.size() == 1) {
-                        LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
+                        SRV_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
                         params.sampling.grammar_trigger_tokens.push_back(ids[0]);
                         params.sampling.preserved_tokens.insert(ids[0]);
                         continue;
                     }
-                    LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
+                    SRV_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
                     params.sampling.grammar_trigger_words.push_back(trigger);
                 }
             }
@@ -381,11 +381,11 @@ struct server_task {
                 for (const auto & t : *preserved_tokens) {
                     auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
                     if (ids.size() == 1) {
-                        LOG_DBG("Preserved token: %d\n", ids[0]);
+                        SRV_DBG("Preserved token: %d\n", ids[0]);
                         params.sampling.preserved_tokens.insert(ids[0]);
                     } else {
                         // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
-                        LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
+                        SRV_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
                     }
                 }
             }
@@ -717,7 +717,7 @@ struct server_task_result_cmpl_final : server_task_result {
         std::string finish_reason = "length";
         common_chat_msg msg;
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
-            LOG_DBG("Parsing chat message: %s\n", content.c_str());
+            SRV_DBG("Parsing chat message: %s\n", content.c_str());
            msg = common_chat_parse(content, oaicompat_chat_format);
             finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
         } else {
@@ -1885,7 +1885,7 @@ struct server_context {
         }
 
         if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
-            LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
+            SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
             chat_templates = common_chat_templates_from_model(model, "chatml");
         } else {
             chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
@@ -3355,10 +3355,10 @@ static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
     // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
-    LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
+    SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
 
-    LOG_DBG("request: %s\n", req.body.c_str());
-    LOG_DBG("response: %s\n", res.body.c_str());
+    SRV_DBG("request: %s\n", req.body.c_str());
+    SRV_DBG("response: %s\n", res.body.c_str());
 }
 
 std::function<void(int)> shutdown_handler;
@@ -3860,7 +3860,9 @@ int main(int argc, char ** argv) {
         try {
             const auto & prompt = data.at("prompt");
-            LOG_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
+            // TODO: this log can become very long, put it behind a flag or think about a more compact format
+            //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
+
             std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
             tasks.reserve(tokenized_prompts.size());
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {


@@ -23,6 +23,7 @@ export default function MarkdownDisplay({ content }: { content: string }) {
           button: (props) => (
             <CopyCodeButton {...props} origContent={preprocessedContent} />
           ),
+          // note: do not use "pre", "p" or other basic html elements here, it will cause the node to re-render when the message is being generated (this should be a bug with react-markdown, not sure how to fix it)
         }}
       >
         {preprocessedContent}


@@ -3,6 +3,7 @@ import { useAppContext } from '../utils/app.context';
 import { CONFIG_DEFAULT, CONFIG_INFO } from '../Config';
 import { isDev } from '../Config';
 import StorageUtils from '../utils/storage';
+import { isBoolean, isNumeric, isString } from '../utils/misc';
 
 type SettKey = keyof typeof CONFIG_DEFAULT;
@@ -52,7 +53,42 @@ export default function SettingDialog({
   };
 
   const handleSave = () => {
-    saveConfig(localConfig);
+    // copy the local config to prevent direct mutation
+    const newConfig: typeof CONFIG_DEFAULT = JSON.parse(
+      JSON.stringify(localConfig)
+    );
+    // validate the config
+    for (const key in newConfig) {
+      const value = newConfig[key as SettKey];
+      const mustBeBoolean = isBoolean(CONFIG_DEFAULT[key as SettKey]);
+      const mustBeString = isString(CONFIG_DEFAULT[key as SettKey]);
+      const mustBeNumeric = isNumeric(CONFIG_DEFAULT[key as SettKey]);
+      if (mustBeString) {
+        if (!isString(value)) {
+          alert(`Value for ${key} must be string`);
+          return;
+        }
+      } else if (mustBeNumeric) {
+        const trimedValue = value.toString().trim();
+        const numVal = Number(trimedValue);
+        if (isNaN(numVal) || !isNumeric(numVal) || trimedValue.length === 0) {
+          alert(`Value for ${key} must be numeric`);
+          return;
+        }
+        // force conversion to number
+        // @ts-expect-error this is safe
+        newConfig[key] = numVal;
+      } else if (mustBeBoolean) {
+        if (!isBoolean(value)) {
+          alert(`Value for ${key} must be boolean`);
+          return;
+        }
+      } else {
+        console.error(`Unknown default type for key ${key}`);
+      }
+    }
+    if (isDev) console.log('Saving config', newConfig);
+    saveConfig(newConfig);
     onClose();
   };
@@ -66,6 +102,11 @@ export default function SettingDialog({
     onClose();
   };
 
+  const onChange = (key: SettKey) => (value: string | boolean) => {
+    // note: we do not perform validation here, because we may get incomplete value as user is still typing it
+    setLocalConfig({ ...localConfig, [key]: value });
+  };
+
   return (
     <dialog className={`modal ${show ? 'modal-open' : ''}`}>
       <div className="modal-box">
@@ -79,9 +120,7 @@ export default function SettingDialog({
             configKey="apiKey"
             configDefault={CONFIG_DEFAULT}
             value={localConfig.apiKey}
-            onChange={(value) =>
-              setLocalConfig({ ...localConfig, apiKey: value })
-            }
+            onChange={onChange('apiKey')}
           />
 
           <label className="form-control mb-2">
@@ -92,12 +131,7 @@ export default function SettingDialog({
              className="textarea textarea-bordered h-24"
              placeholder={`Default: ${CONFIG_DEFAULT.systemMessage}`}
              value={localConfig.systemMessage}
-             onChange={(e) =>
-               setLocalConfig({
-                 ...localConfig,
-                 systemMessage: e.target.value,
-               })
-             }
+             onChange={(e) => onChange('systemMessage')(e.target.value)}
            />
          </label>
@@ -107,9 +141,7 @@ export default function SettingDialog({
              configKey={key}
              configDefault={CONFIG_DEFAULT}
              value={localConfig[key]}
-             onChange={(value) =>
-               setLocalConfig({ ...localConfig, [key]: value })
-             }
+             onChange={onChange(key)}
            />
          ))}
@@ -123,9 +155,7 @@ export default function SettingDialog({
            configKey="samplers"
            configDefault={CONFIG_DEFAULT}
            value={localConfig.samplers}
-           onChange={(value) =>
-             setLocalConfig({ ...localConfig, samplers: value })
-           }
+           onChange={onChange('samplers')}
          />
          {OTHER_SAMPLER_KEYS.map((key) => (
            <SettingsModalShortInput
@@ -133,9 +163,7 @@ export default function SettingDialog({
              configKey={key}
              configDefault={CONFIG_DEFAULT}
              value={localConfig[key]}
-             onChange={(value) =>
-               setLocalConfig({ ...localConfig, [key]: value })
-             }
+             onChange={onChange(key)}
            />
          ))}
        </div>
@@ -152,9 +180,7 @@ export default function SettingDialog({
              configKey={key}
              configDefault={CONFIG_DEFAULT}
              value={localConfig[key]}
-             onChange={(value) =>
-               setLocalConfig({ ...localConfig, [key]: value })
-             }
+             onChange={onChange(key)}
            />
          ))}
        </div>
@@ -171,10 +197,7 @@ export default function SettingDialog({
                className="checkbox"
                checked={localConfig.showThoughtInProgress}
                onChange={(e) =>
-                 setLocalConfig({
-                   ...localConfig,
-                   showThoughtInProgress: e.target.checked,
-                 })
+                 onChange('showThoughtInProgress')(e.target.checked)
                }
              />
              <span className="ml-4">
@@ -187,10 +210,7 @@ export default function SettingDialog({
                className="checkbox"
                checked={localConfig.excludeThoughtOnReq}
                onChange={(e) =>
-                 setLocalConfig({
-                   ...localConfig,
-                   excludeThoughtOnReq: e.target.checked,
-                 })
+                 onChange('excludeThoughtOnReq')(e.target.checked)
               }
             />
             <span className="ml-4">
@@ -220,10 +240,7 @@ export default function SettingDialog({
                className="checkbox"
                checked={localConfig.showTokensPerSecond}
                onChange={(e) =>
-                 setLocalConfig({
-                   ...localConfig,
-                   showTokensPerSecond: e.target.checked,
-                 })
+                 onChange('showTokensPerSecond')(e.target.checked)
               }
             />
             <span className="ml-4">Show tokens per second</span>
@@ -245,9 +262,7 @@ export default function SettingDialog({
              className="textarea textarea-bordered h-24"
              placeholder='Example: { "mirostat": 1, "min_p": 0.1 }'
              value={localConfig.custom}
-             onChange={(e) =>
-               setLocalConfig({ ...localConfig, custom: e.target.value })
-             }
+             onChange={(e) => onChange('custom')(e.target.value)}
            />
          </label>
        </div>


@@ -16,7 +16,7 @@
 
 #include "common.cuh"
 
-#if CUDART_VERSION >= 11800
+#if CUDART_VERSION >= 11080
 
 static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
     int ret = 0;
@@ -50,7 +50,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
 
     return ret_low | ret_high;
 }
 
-#endif // CUDART_VERSION >= 11800
+#endif // CUDART_VERSION >= 11080
 
 template <typename T>

@@ -2780,8 +2780,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
 
     std::string device_name = props2.properties.deviceName.data();
-    GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n",
-                   idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str());
+    GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
+                   idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
+                   props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());
 
     if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
         GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");

@@ -1114,11 +1114,12 @@ extern "C" {
     };
 
     struct llama_sampler {
-        struct llama_sampler_i * iface;
-        llama_sampler_context_t  ctx;
+        const struct llama_sampler_i * iface;
+        llama_sampler_context_t        ctx;
     };
 
     // mirror of llama_sampler_i:
+    LLAMA_API struct llama_sampler * llama_sampler_init  (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
     LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl);
     LLAMA_API void                   llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
     LLAMA_API void                   llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p);


@@ -316,6 +316,13 @@ static uint32_t get_rng_seed(uint32_t seed) {
 
 // llama_sampler API
 
+struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
+    return new llama_sampler {
+        /* .iface = */ iface,
+        /* .ctx   = */ ctx,
+    };
+}
+
 const char * llama_sampler_name(const struct llama_sampler * smpl) {
     if (!smpl->iface) {
         return "(null)";
@@ -347,10 +354,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
     }
 
     if (smpl->ctx == nullptr) {
-        return new llama_sampler {
+        return llama_sampler_init(
             /* .iface = */ smpl->iface,
-            /* .ctx   = */ nullptr,
-        };
+            /* .ctx   = */ nullptr
+        );
     }
 
     GGML_ABORT("the sampler does not support cloning");
@@ -472,15 +479,15 @@ static struct llama_sampler_i llama_sampler_chain_i = {
 };
 
 struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_chain_i,
         /* .ctx   = */ new llama_sampler_chain {
             /* .params      = */ params,
             /* .samplers    = */ {},
             /* .t_sample_us = */ 0,
             /* .n_sample    = */ 0,
-        },
-    };
+        }
+    );
 }
 
 void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
@@ -546,10 +553,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
 };
 
 struct llama_sampler * llama_sampler_init_greedy() {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_greedy_i,
-        /* .ctx   = */ nullptr,
-    };
+        /* .ctx   = */ nullptr
+    );
 }
 
 // dist
@@ -608,14 +615,14 @@
 struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_dist_i,
         /* .ctx   = */ new llama_sampler_dist {
             /* .seed     = */ seed,
            /* .seed_cur = */ seed_cur,
            /* .rng      = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // softmax
@@ -638,10 +645,10 @@
 };
 
 struct llama_sampler * llama_sampler_init_softmax() {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_softmax_i,
-        /* .ctx   = */ nullptr,
-    };
+        /* .ctx   = */ nullptr
+    );
 }
 
 // top-k
@@ -678,12 +685,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
 };
 
 struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_k_i,
         /* .ctx   = */ new llama_sampler_top_k {
            /* .k = */ k,
-        },
-    };
+        }
+    );
}

// top-p
@@ -744,13 +751,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
 };
 
 struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_p_i,
         /* .ctx   = */ new llama_sampler_top_p {
            /* .p        = */ p,
            /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }
 
 // min-p
@@ -840,13 +847,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
 };
 
 struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_min_p_i,
         /* .ctx   = */ new llama_sampler_min_p {
            /* .p        = */ p,
            /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }
 
 // typical
@@ -939,13 +946,13 @@ static struct llama_sampler_i llama_sampler_typical_i = {
 };
 
 struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_typical_i,
         /* .ctx   = */ new llama_sampler_typical {
            /* .p        = */ p,
            /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }
 
 // temp
@@ -983,12 +990,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
 };
 
 struct llama_sampler * llama_sampler_init_temp(float temp) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_i,
         /* .ctx   = */ new llama_sampler_temp {
            /*.temp = */ temp,
-        },
-    };
+        }
+    );
 }
 
 // temp-ext
@@ -1093,14 +1100,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
 };
 
 struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_ext_i,
         /* .ctx   = */ new llama_sampler_temp_ext {
            /* .temp     = */ temp,
            /* .delta    = */ delta,
            /* .exponent = */ exponent,
-        },
-    };
+        }
+    );
 }
 
 // xtc
@@ -1185,7 +1192,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
 
 struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_xtc_i,
         /* .ctx   = */ new llama_sampler_xtc {
            /* .probability = */ p,
@@ -1194,8 +1201,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
            /* .seed        = */ seed,
            /* .seed_cur    = */ seed_cur,
            /* .rng         = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // mirostat
@@ -1292,7 +1299,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
 
 struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_mirostat_i,
         /* .ctx   = */ new llama_sampler_mirostat {
            /* .n_vocab  = */ n_vocab,
@@ -1303,8 +1310,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
            /* .m        = */ m,
            /* .mu       = */ 2.0f*tau,
            /* .rng      = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // mirostat v2
@@ -1391,7 +1398,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
 
 struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_mirostat_v2_i,
         /* .ctx   = */ new llama_sampler_mirostat_v2 {
            /* .seed     = */ seed,
@@ -1400,8 +1407,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
            /* .eta      = */ eta,
            /* .mu       = */ 2.0f*tau,
            /* .rng      = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // grammar
@@ -1528,10 +1535,10 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
         };
     }
 
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_grammar_i,
-        /* .ctx   = */ ctx,
-    };
+        /* .ctx   = */ ctx
+    );
 }
 
 struct llama_sampler * llama_sampler_init_grammar(
@@ -1678,7 +1685,7 @@ struct llama_sampler * llama_sampler_init_penalties(
         float penalty_present) {
     penalty_last_n = std::max(penalty_last_n, 0);
 
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
            /* .penalty_last_n  = */ penalty_last_n,
@@ -1687,8 +1694,8 @@ struct llama_sampler * llama_sampler_init_penalties(
            /* .penalty_present = */ penalty_present,
            /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
            /* .token_count     = */ {},
-        },
-    };
+        }
+    );
 }
 
 // DRY
@@ -2041,7 +2048,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
         }
     }
 
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_dry_i,
         /* .ctx   = */ new llama_sampler_dry {
            /* .total_context_size = */ context_size,
@@ -2053,8 +2060,8 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
            /* .dry_repeat_count     = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
            /* .dry_max_token_repeat = */ {},
            /* .last_tokens          = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
-        },
-    };
+        }
+    );
 }
 
 // wrapper for test-sampling.cpp
@@ -2155,14 +2162,14 @@ struct llama_sampler * llama_sampler_init_logit_bias(
         int32_t   n_vocab,
         int32_t   n_logit_bias,
         const llama_logit_bias * logit_bias) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_logit_bias_i,
         /* .ctx   = */ new llama_sampler_logit_bias {
            /* .n_vocab    = */ n_vocab,
            /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
            /* .to_search  = */ {},
-        },
-    };
+        }
+    );
 }
 
 // infill
@@ -2377,14 +2384,14 @@ static struct llama_sampler_i llama_sampler_infill_i = {
 };
 
 struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_infill_i,
         /* .ctx   = */ new llama_sampler_infill {
            /* .vocab = */ vocab,
            /* .buf0  = */ std::vector<char>(512),
            /* .buf1  = */ std::vector<char>(512),
-        },
-    };
+        }
+    );
 }
 
 // utils


@@ -8801,12 +8801,14 @@ static int llama_decode_impl(
     //llama_synchronize(&lctx);
 
     // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
-        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
+    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+        // - do not defrag small contexts (i.e. < 2048 tokens)
+        // - count the padding towards the number of used tokens
+        const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + llama_kv_cache_get_padding(cparams))/float(kv_self.n)) : 0.0f;
 
         // queue defragmentation for next llama_kv_cache_update
         if (fragmentation > cparams.defrag_thold) {
-            //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
+            LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
 
             llama_kv_cache_defrag(kv_self);
         }
@@ -9428,7 +9430,6 @@ static struct llama_model * llama_model_load_from_file_impl(
         struct llama_model_params params) {
     ggml_time_init();
 
-    llama_model * model = new llama_model(params);
 
     unsigned cur_percentage = 0;
     if (params.progress_callback == NULL) {
@@ -9447,6 +9448,8 @@ static struct llama_model * llama_model_load_from_file_impl(
         };
     }
 
+    llama_model * model = new llama_model(params);
+
     // create list of devices to use with this model
     if (params.devices) {
         for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {

@@ -618,7 +618,14 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
-        result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        try {
+            result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        }
+        catch (const std::invalid_argument & /*ex*/) {
+            // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
+            ++offset;
+            result.emplace_back(0xFFFD); // replacement character
+        }
     }
     return result;
 }