mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-12 14:03:20 +02:00
Compare commits
50 Commits
gg/server-
...
b4720
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fc1b0d0936 | ||
|
|
89daa2564f | ||
|
|
300907b211 | ||
|
|
94b87f87b5 | ||
|
|
dbc2ec59b5 | ||
|
|
3d68f034da | ||
|
|
38e32eb6a0 | ||
|
|
a4f011e8d0 | ||
|
|
a7b8ce2260 | ||
|
|
04045bb842 | ||
|
|
8a8c4ceb60 | ||
|
|
c1f958c038 | ||
|
|
c48f630d1c | ||
|
|
bd6e55bfd3 | ||
|
|
c7f460ab88 | ||
|
|
27e8a23300 | ||
|
|
e4376270d9 | ||
|
|
3e69319772 | ||
|
|
a394039db0 | ||
|
|
be3bbd6215 | ||
|
|
31afcbee0e | ||
|
|
5c4284d57b | ||
|
|
bfd11a2344 | ||
|
|
0fb77f821f | ||
|
|
e598697d63 | ||
|
|
fef0cbeadf | ||
|
|
748ee9fe93 | ||
|
|
198b1ec611 | ||
|
|
c3d6af7cd2 | ||
|
|
369be5598a | ||
|
|
4078c77f98 | ||
|
|
90e4dba461 | ||
|
|
a18f481f99 | ||
|
|
b9ab0a4d0b | ||
|
|
7b891bdc86 | ||
|
|
81732619fd | ||
|
|
507f9174fe | ||
|
|
19b392d58d | ||
|
|
0893e0114e | ||
|
|
d7b31a9d84 | ||
|
|
9ac3457b39 | ||
|
|
c2a67efe38 | ||
|
|
b044a0fe3c | ||
|
|
19d3c8293b | ||
|
|
98f6b0fd1e | ||
|
|
55ac8c7791 | ||
|
|
e6e6583199 | ||
|
|
aaa5505307 | ||
|
|
bdcf8b6a56 | ||
|
|
4d3465c5ae |
@@ -1,6 +1,6 @@
|
||||
ARG UBUNTU_VERSION=22.04
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG CUDA_VERSION=12.6.0
|
||||
ARG CUDA_VERSION=12.4.0
|
||||
# Target the CUDA build image
|
||||
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
ARG UBUNTU_VERSION=22.04
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG MUSA_VERSION=rc3.1.0
|
||||
ARG MUSA_VERSION=rc3.1.1
|
||||
# Target the MUSA build image
|
||||
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
|
||||
32
.github/workflows/build.yml
vendored
32
.github/workflows/build.yml
vendored
@@ -401,7 +401,35 @@ jobs:
|
||||
run: |
|
||||
cd build
|
||||
# This is using llvmpipe and runs slower than other backends
|
||||
ctest -L main --verbose --timeout 1800
|
||||
ctest -L main --verbose --timeout 2700
|
||||
|
||||
- name: Determine tag name
|
||||
id: tag
|
||||
shell: bash
|
||||
run: |
|
||||
BUILD_NUMBER="$(git rev-list --count HEAD)"
|
||||
SHORT_HASH="$(git rev-parse --short=7 HEAD)"
|
||||
if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
|
||||
echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT
|
||||
else
|
||||
SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
|
||||
echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp
|
||||
zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip ./build/bin/*
|
||||
|
||||
- name: Upload artifacts
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip
|
||||
name: llama-bin-ubuntu-vulkan-x64.zip
|
||||
|
||||
ubuntu-22-cmake-hip:
|
||||
runs-on: ubuntu-22.04
|
||||
@@ -443,7 +471,7 @@ jobs:
|
||||
|
||||
ubuntu-22-cmake-musa:
|
||||
runs-on: ubuntu-22.04
|
||||
container: mthreads/musa:rc3.1.0-devel-ubuntu22.04
|
||||
container: mthreads/musa:rc3.1.1-devel-ubuntu22.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
16
README.md
16
README.md
@@ -235,6 +235,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
| [HIP](docs/build.md#hip) | AMD GPU |
|
||||
| [Vulkan](docs/build.md#vulkan) | GPU |
|
||||
| [CANN](docs/build.md#cann) | Ascend NPU |
|
||||
| [OpenCL](docs/backend/OPENCL.md) | Adreno GPU |
|
||||
|
||||
## Building the project
|
||||
|
||||
@@ -518,5 +519,18 @@ If your issue is with model generation quality, then please at least scan the fo
|
||||
- [Aligning language models to follow instructions](https://openai.com/research/instruction-following)
|
||||
- [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155)
|
||||
|
||||
#### References
|
||||
## Completions
|
||||
Command-line completion is available for some environments.
|
||||
|
||||
#### Bash Completion
|
||||
```bash
|
||||
$ build/bin/llama-cli --completion-bash > ~/.llama-completion.bash
|
||||
$ source ~/.llama-completion.bash
|
||||
```
|
||||
Optionally this can be added to your `.bashrc` or `.bash_profile` to load it
|
||||
automatically. For example:
|
||||
```console
|
||||
$ echo "source ~/.llama-completion.bash" >> ~/.bashrc
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
@@ -96,6 +96,22 @@ if (LLAMA_LLGUIDANCE)
|
||||
include(ExternalProject)
|
||||
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
|
||||
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
|
||||
|
||||
# Set the correct library file extension based on platform
|
||||
if (WIN32)
|
||||
set(LLGUIDANCE_LIB_NAME "llguidance.lib")
|
||||
# Add Windows-specific libraries
|
||||
set(LLGUIDANCE_PLATFORM_LIBS
|
||||
ws2_32 # Windows Sockets API
|
||||
userenv # For GetUserProfileDirectoryW
|
||||
ntdll # For NT functions
|
||||
bcrypt # For BCryptGenRandom
|
||||
)
|
||||
else()
|
||||
set(LLGUIDANCE_LIB_NAME "libllguidance.a")
|
||||
set(LLGUIDANCE_PLATFORM_LIBS "")
|
||||
endif()
|
||||
|
||||
ExternalProject_Add(llguidance_ext
|
||||
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
|
||||
# v0.6.12:
|
||||
@@ -106,17 +122,18 @@ if (LLAMA_LLGUIDANCE)
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND cargo build --release
|
||||
INSTALL_COMMAND ""
|
||||
BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/libllguidance.a ${LLGUIDANCE_PATH}/llguidance.h
|
||||
BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME} ${LLGUIDANCE_PATH}/llguidance.h
|
||||
UPDATE_COMMAND ""
|
||||
)
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_LLGUIDANCE)
|
||||
|
||||
add_library(llguidance STATIC IMPORTED)
|
||||
set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/libllguidance.a)
|
||||
set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME})
|
||||
add_dependencies(llguidance llguidance_ext)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance)
|
||||
# Add platform libraries to the main target
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
|
||||
endif ()
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC .)
|
||||
|
||||
137
common/arg.cpp
137
common/arg.cpp
@@ -365,6 +365,112 @@ static void common_params_print_usage(common_params_context & ctx_arg) {
|
||||
print_options(specific_options);
|
||||
}
|
||||
|
||||
static void common_params_print_completion(common_params_context & ctx_arg) {
|
||||
std::vector<common_arg *> common_options;
|
||||
std::vector<common_arg *> sparam_options;
|
||||
std::vector<common_arg *> specific_options;
|
||||
|
||||
for (auto & opt : ctx_arg.options) {
|
||||
if (opt.is_sparam) {
|
||||
sparam_options.push_back(&opt);
|
||||
} else if (opt.in_example(ctx_arg.ex)) {
|
||||
specific_options.push_back(&opt);
|
||||
} else {
|
||||
common_options.push_back(&opt);
|
||||
}
|
||||
}
|
||||
|
||||
printf("_llama_completions() {\n");
|
||||
printf(" local cur prev opts\n");
|
||||
printf(" COMPREPLY=()\n");
|
||||
printf(" cur=\"${COMP_WORDS[COMP_CWORD]}\"\n");
|
||||
printf(" prev=\"${COMP_WORDS[COMP_CWORD-1]}\"\n\n");
|
||||
|
||||
printf(" opts=\"");
|
||||
auto print_options = [](const std::vector<common_arg *> & options) {
|
||||
for (const common_arg * opt : options) {
|
||||
for (const char * arg : opt->args) {
|
||||
printf("%s ", arg);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
print_options(common_options);
|
||||
print_options(sparam_options);
|
||||
print_options(specific_options);
|
||||
printf("\"\n\n");
|
||||
|
||||
printf(" case \"$prev\" in\n");
|
||||
printf(" --model)\n");
|
||||
printf(" COMPREPLY=( $(compgen -f -X '!*.gguf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
|
||||
printf(" return 0\n");
|
||||
printf(" ;;\n");
|
||||
printf(" --grammar-file)\n");
|
||||
printf(" COMPREPLY=( $(compgen -f -X '!*.gbnf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
|
||||
printf(" return 0\n");
|
||||
printf(" ;;\n");
|
||||
printf(" --chat-template-file)\n");
|
||||
printf(" COMPREPLY=( $(compgen -f -X '!*.jinja' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
|
||||
printf(" return 0\n");
|
||||
printf(" ;;\n");
|
||||
printf(" *)\n");
|
||||
printf(" COMPREPLY=( $(compgen -W \"${opts}\" -- \"$cur\") )\n");
|
||||
printf(" return 0\n");
|
||||
printf(" ;;\n");
|
||||
printf(" esac\n");
|
||||
printf("}\n\n");
|
||||
|
||||
std::set<std::string> executables = {
|
||||
"llama-batched",
|
||||
"llama-batched-bench",
|
||||
"llama-bench",
|
||||
"llama-cli",
|
||||
"llama-convert-llama2c-to-ggml",
|
||||
"llama-cvector-generator",
|
||||
"llama-embedding",
|
||||
"llama-eval-callback",
|
||||
"llama-export-lora",
|
||||
"llama-gbnf-validator",
|
||||
"llama-gen-docs",
|
||||
"llama-gguf",
|
||||
"llama-gguf-hash",
|
||||
"llama-gguf-split",
|
||||
"llama-gritlm",
|
||||
"llama-imatrix",
|
||||
"llama-infill",
|
||||
"llama-llava-cli",
|
||||
"llama-llava-clip-quantize-cli",
|
||||
"llama-lookahead",
|
||||
"llama-lookup",
|
||||
"llama-lookup-create",
|
||||
"llama-lookup-merge",
|
||||
"llama-lookup-stats",
|
||||
"llama-minicpmv-cli",
|
||||
"llama-parallel",
|
||||
"llama-passkey",
|
||||
"llama-perplexity",
|
||||
"llama-q8dot",
|
||||
"llama-quantize",
|
||||
"llama-quantize-stats",
|
||||
"llama-qwen2vl-cli",
|
||||
"llama-retrieval",
|
||||
"llama-run",
|
||||
"llama-save-load-state",
|
||||
"llama-server",
|
||||
"llama-simple",
|
||||
"llama-simple-chat",
|
||||
"llama-speculative",
|
||||
"llama-speculative-simple",
|
||||
"llama-tokenize",
|
||||
"llama-tts",
|
||||
"llama-vdot"
|
||||
};
|
||||
|
||||
for (const auto& exe : executables) {
|
||||
printf("complete -F _llama_completions %s\n", exe.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & value) {
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
auto dev_names = string_split<std::string>(value, ',');
|
||||
@@ -426,6 +532,10 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
|
||||
}
|
||||
exit(0);
|
||||
}
|
||||
if (ctx_arg.params.completion) {
|
||||
common_params_print_completion(ctx_arg);
|
||||
exit(0);
|
||||
}
|
||||
} catch (const std::invalid_argument & ex) {
|
||||
fprintf(stderr, "%s\n", ex.what());
|
||||
ctx_arg.params = params_org;
|
||||
@@ -494,6 +604,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
exit(0);
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--completion-bash"},
|
||||
"print source-able bash completion script for llama.cpp",
|
||||
[](common_params & params) {
|
||||
params.completion = true;
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--verbose-prompt"},
|
||||
string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"),
|
||||
@@ -674,7 +791,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--no-context-shift"},
|
||||
string_format("disables context shift on inifinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
|
||||
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
|
||||
[](common_params & params) {
|
||||
params.ctx_shift = false;
|
||||
}
|
||||
@@ -946,6 +1063,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.sampling.min_p = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--top-nsigma"}, "N",
|
||||
string_format("top-n-sigma sampling (default: %.1f, -1.0 = disabled)", params.sampling.top_n_sigma),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.top_n_sigma = std::stof(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN}).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--xtc-probability"}, "N",
|
||||
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
|
||||
@@ -1975,6 +2099,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-format"}, "FORMAT",
|
||||
"reasoning format (default: deepseek; allowed values: deepseek, none)\n"
|
||||
"controls whether thought tags are extracted from the response, and in which format they're returned. 'none' leaves thoughts unparsed in `message.content`, 'deepseek' puts them in `message.reasoning_content` (for DeepSeek R1 & Command R7B only).\n"
|
||||
"only supported for non-streamed responses",
|
||||
[](common_params & params, const std::string & value) {
|
||||
/**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
|
||||
else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
|
||||
else { std::invalid_argument("invalid value"); }
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
|
||||
add_opt(common_arg(
|
||||
{"--chat-template"}, "JINJA_TEMPLATE",
|
||||
string_format(
|
||||
|
||||
@@ -249,16 +249,30 @@ class chat_template {
|
||||
inputs.add_generation_prompt = false;
|
||||
full = apply(inputs);
|
||||
}
|
||||
|
||||
if (full.find(prefix) != 0) {
|
||||
if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
|
||||
prefix = prefix.substr(0, prefix.size() - eos_token_.size());
|
||||
auto eos_pos_last = full.rfind(eos_token_);
|
||||
if (eos_pos_last == prefix.size() - eos_token_.size() ||
|
||||
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
|
||||
full = full.substr(0, eos_pos_last);
|
||||
}
|
||||
size_t common_prefix_length = 0;
|
||||
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
|
||||
if (prefix[i] != full[i]) {
|
||||
break;
|
||||
}
|
||||
if (prefix[i] == '<') {
|
||||
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
|
||||
// but it removes thinking tags for past messages.
|
||||
// The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
|
||||
continue;
|
||||
}
|
||||
common_prefix_length = i + 1;
|
||||
}
|
||||
if (full.find(prefix) != 0) {
|
||||
auto example = full.substr(common_prefix_length);
|
||||
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
|
||||
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
|
||||
} else {
|
||||
tool_call_example_ = example;
|
||||
}
|
||||
tool_call_example_ = full.substr(prefix.size());
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
|
||||
@@ -363,7 +377,7 @@ class chat_template {
|
||||
if (polyfill_tools) {
|
||||
adjusted_messages = add_system(inputs.messages,
|
||||
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
|
||||
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
|
||||
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
|
||||
} else {
|
||||
adjusted_messages = inputs.messages;
|
||||
}
|
||||
|
||||
319
common/chat.cpp
319
common/chat.cpp
@@ -12,11 +12,13 @@ std::string common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: return "DeepSeek R1 (extract reasoning)";
|
||||
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
@@ -105,7 +107,6 @@ static common_chat_msg parse_json_tool_calls(
|
||||
std::sregex_iterator rend;
|
||||
std::sregex_iterator rit(it, end, function_regex);
|
||||
if (rit == rend) {
|
||||
fprintf(stderr, "No more tool calls found\n");
|
||||
result.content += std::string(it, end);
|
||||
break;
|
||||
}
|
||||
@@ -115,14 +116,21 @@ static common_chat_msg parse_json_tool_calls(
|
||||
|
||||
json arguments;
|
||||
if (!parse_json(it, end, arguments)) {
|
||||
throw std::runtime_error("Failed to parse json tool call arguments");
|
||||
throw std::runtime_error("Failed to parse json tool call arguments: " + input);
|
||||
}
|
||||
if (!std::regex_search(it, end, match, close_regex)) {
|
||||
throw std::runtime_error("Malformed input, missing closing pattern");
|
||||
throw std::runtime_error("Malformed input, missing closing pattern: " + input);
|
||||
}
|
||||
it = match.suffix().first;
|
||||
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
|
||||
}
|
||||
|
||||
if (!result.tool_calls.empty()) {
|
||||
if (!string_strip(result.content).empty()) {
|
||||
LOG_WRN("Content found with tool calls: %s\n", result.content.c_str());
|
||||
}
|
||||
result.content = "";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -134,11 +142,11 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
|
||||
result.role = "assistant";
|
||||
const auto process_tool_calls = [&](const json & tool_calls) {
|
||||
for (const auto & tool_call : tool_calls) {
|
||||
const auto & arguments = tool_call["arguments"];
|
||||
const auto & arguments = tool_call.at("arguments");
|
||||
result.tool_calls.push_back({
|
||||
tool_call["name"],
|
||||
tool_call.at("name"),
|
||||
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
tool_call.contains("id") ? tool_call["id"] : "",
|
||||
tool_call.contains("id") ? tool_call.at("id") : "",
|
||||
});
|
||||
}
|
||||
};
|
||||
@@ -155,7 +163,7 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
|
||||
|
||||
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
|
||||
for (const auto & tool : tools) {
|
||||
if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
|
||||
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
|
||||
LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
@@ -190,27 +198,27 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
|
||||
|
||||
auto tool_call_schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
const auto & function = tool.at("function");
|
||||
auto tool_schema = json {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", function["name"]},
|
||||
{"const", function.at("name")},
|
||||
}},
|
||||
{"arguments", function["parameters"]},
|
||||
{"arguments", function.at("parameters")},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments"})},
|
||||
};
|
||||
if (function.contains("description")) {
|
||||
tool_schema["description"] = function["description"];
|
||||
tool_schema["description"] = function.at("description");
|
||||
}
|
||||
if (inputs.parallel_tool_calls) {
|
||||
tool_schema["properties"]["id"] = {
|
||||
tool_schema.at("properties")["id"] = {
|
||||
{"type", "string"},
|
||||
{"minLength", 4},
|
||||
};
|
||||
tool_schema["required"].push_back("id");
|
||||
tool_schema.at("required").push_back("id");
|
||||
}
|
||||
tool_call_schemas.emplace_back(tool_schema);
|
||||
});
|
||||
@@ -275,21 +283,21 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
if (data.contains("tool_calls")) {
|
||||
for (const auto & tool_call : data["tool_calls"]) {
|
||||
for (const auto & tool_call : data.at("tool_calls")) {
|
||||
result.tool_calls.push_back({
|
||||
tool_call["name"],
|
||||
tool_call["arguments"].dump(),
|
||||
tool_call.contains("id") ? tool_call["id"] : "",
|
||||
tool_call.at("name"),
|
||||
tool_call.at("arguments").dump(),
|
||||
tool_call.contains("id") ? tool_call.at("id") : "",
|
||||
});
|
||||
}
|
||||
} else if (data.contains("tool_call")) {
|
||||
result.tool_calls.push_back({
|
||||
data["tool_call"]["name"],
|
||||
data["tool_call"]["arguments"].dump(),
|
||||
data.at("tool_call").at("name"),
|
||||
data.at("tool_call").at("arguments").dump(),
|
||||
/* id= */ "",
|
||||
});
|
||||
} else if (data.contains("response")) {
|
||||
const auto & response = data["response"];
|
||||
const auto & response = data.at("response");
|
||||
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
|
||||
}
|
||||
return result;
|
||||
@@ -301,7 +309,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
const auto & function = tool.at("function");
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
@@ -309,9 +317,9 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
||||
// It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", function["name"]},
|
||||
{"const", function.at("name")},
|
||||
}},
|
||||
{"arguments", function["parameters"]},
|
||||
{"arguments", function.at("parameters")},
|
||||
{"id", {
|
||||
{"type", "string"},
|
||||
// Nemo's template expects a 9-character alphanumeric ID.
|
||||
@@ -346,7 +354,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
const auto & function = tool.at("function");
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
@@ -357,9 +365,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
||||
}},
|
||||
{"tool_name", {
|
||||
{"type", "string"},
|
||||
{"const", function["name"]},
|
||||
{"const", function.at("name")},
|
||||
}},
|
||||
{"parameters", function["parameters"]},
|
||||
{"parameters", function.at("parameters")},
|
||||
}},
|
||||
{"required", json::array({"tool_call_id", "tool_name", "parameters"})},
|
||||
});
|
||||
@@ -382,39 +390,65 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
||||
"<|END_THINKING|>",
|
||||
"<|END_ACTION|>",
|
||||
};
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
|
||||
auto adjusted_messages = json::array();
|
||||
for (const auto & msg : inputs.messages) {
|
||||
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
|
||||
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
|
||||
if (has_reasoning_content && has_tool_calls) {
|
||||
auto adjusted_message = msg;
|
||||
adjusted_message["tool_plan"] = msg.at("reasoning_content");
|
||||
adjusted_message.erase("reasoning_content");
|
||||
adjusted_messages.push_back(adjusted_message);
|
||||
} else {
|
||||
adjusted_messages.push_back(msg);
|
||||
}
|
||||
}
|
||||
data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
|
||||
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING : COMMON_CHAT_FORMAT_COMMAND_R7B;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
|
||||
static std::regex response_regex("<\\|START_RESPONSE\\|>([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
|
||||
static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
|
||||
static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
|
||||
static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)");
|
||||
static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
|
||||
static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
|
||||
|
||||
std::smatch match;
|
||||
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
if (std::regex_match(input, match, response_regex)) {
|
||||
result.content = match[1].str();
|
||||
} else if (std::regex_match(input, match, thought_action_regex)) {
|
||||
result.tool_plan = match[1].str();
|
||||
auto actions_str = match[2].str();
|
||||
|
||||
std::string rest = input;
|
||||
|
||||
if (std::regex_match(rest, match, thought_regex)) {
|
||||
if (extract_reasoning) {
|
||||
result.reasoning_content = match[2].str();
|
||||
} else if (!match[2].str().empty()) {
|
||||
// Let the unparsed thinking tags through in content only if their insides aren't empty.
|
||||
result.content = match[1].str();
|
||||
}
|
||||
rest = match[3].str();
|
||||
}
|
||||
if (std::regex_match(rest, match, action_regex)) {
|
||||
auto actions_str = match[1].str();
|
||||
auto actions = json::parse(actions_str);
|
||||
for (const auto & action : actions) {
|
||||
result.tool_calls.push_back({
|
||||
/* .name = */ action["tool_name"],
|
||||
/* .arguments = */ action["parameters"].dump(),
|
||||
/* .id = */ action["tool_call_id"],
|
||||
/* .name = */ action.at("tool_name"),
|
||||
/* .arguments = */ action.at("parameters").dump(),
|
||||
/* .id = */ action.at("tool_call_id"),
|
||||
});
|
||||
}
|
||||
} else if (std::regex_match(rest, match, response_regex)) {
|
||||
auto response = match[1].str();
|
||||
result.content += response;
|
||||
} else {
|
||||
LOG_ERR("Failed to parse command_r output");
|
||||
result.content = input;
|
||||
result.content += rest;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
|
||||
if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
|
||||
if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
|
||||
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
|
||||
}
|
||||
const auto & parameters_properties = parameters.at("properties");
|
||||
@@ -468,9 +502,9 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
||||
};
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
|
||||
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
||||
@@ -546,34 +580,90 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
|
||||
|
||||
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
|
||||
});
|
||||
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
|
||||
data.preserved_tokens = {
|
||||
"<|tool▁sep|>",
|
||||
"<|tool▁call▁end|>",
|
||||
};
|
||||
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
|
||||
}, grammar_options);
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null();
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
|
||||
"```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
|
||||
});
|
||||
// Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
|
||||
// so we accept common variants (then it's all constrained)
|
||||
builder.add_rule("root",
|
||||
"( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) "
|
||||
"(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
|
||||
"\"<|tool▁calls▁end|>\""
|
||||
" space");
|
||||
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false});
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<|tool▁sep|>",
|
||||
"<|tool▁calls▁end|",
|
||||
"<|tool▁call▁end|>",
|
||||
};
|
||||
}, grammar_options);
|
||||
}
|
||||
auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
|
||||
// Hacks to fix the official (broken) prompt.
|
||||
// It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
|
||||
// until the official template is fixed.
|
||||
if (tmpl.source().find("{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}") != std::string::npos) {
|
||||
// Don't leave the chat dangling after tool results
|
||||
if (string_ends_with(prompt, "<|tool▁outputs▁end|>")) {
|
||||
prompt += "<|end▁of▁sentence|>";
|
||||
if (inputs.add_generation_prompt) {
|
||||
prompt += "<|Assistant|>";
|
||||
}
|
||||
}
|
||||
// Fix up tool call delta example added by Minja
|
||||
prompt = std::regex_replace(
|
||||
prompt,
|
||||
std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"),
|
||||
"$1<|tool▁calls▁end|><|end▁of▁sentence|>$2");
|
||||
}
|
||||
data.prompt = prompt;
|
||||
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
||||
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) {
|
||||
static std::regex trigger_regex("<|tool▁calls▁begin|>");
|
||||
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
|
||||
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
||||
static std::regex close_regex("```<|tool▁call▁end|>");
|
||||
return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
|
||||
static std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
|
||||
static std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
|
||||
static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
std::smatch match;
|
||||
if (std::regex_match(input, match, reasoning_content_regex)) {
|
||||
std::string rest;
|
||||
if (extract_reasoning) {
|
||||
msg.reasoning_content = string_strip(match[2].str());
|
||||
} else {
|
||||
msg.content = match[1].str();
|
||||
}
|
||||
rest = match[3].str();
|
||||
|
||||
if (std::regex_search(rest, match, tool_calls_regex)) {
|
||||
auto tool_calls = match[1].str();
|
||||
auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
|
||||
msg.tool_calls = std::move(msg2.tool_calls);
|
||||
} else {
|
||||
msg.content += std::string(rest.begin() + rest.find_first_not_of(" \r\n"), rest.end());
|
||||
}
|
||||
} else {
|
||||
msg.content = input;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
@@ -583,20 +673,20 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
|
||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
});
|
||||
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
const auto & function = tool.at("function");
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", function["name"]},
|
||||
{"const", function.at("name")},
|
||||
}},
|
||||
{"arguments", function["parameters"]},
|
||||
{"arguments", function.at("parameters")},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments", "id"})},
|
||||
});
|
||||
@@ -628,15 +718,15 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
|
||||
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> first_tool_rules;
|
||||
std::vector<std::string> subsequent_tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
||||
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
||||
@@ -716,9 +806,9 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
const auto & parameters = function["parameters"];
|
||||
std::string name = function["name"];
|
||||
const auto & function = tool.at("function");
|
||||
const auto & parameters = function.at("parameters");
|
||||
std::string name = function.at("name");
|
||||
if (name == "python" || name == "ipython") {
|
||||
if (!parameters.contains("type")) {
|
||||
throw std::runtime_error("Missing type in python tool");
|
||||
@@ -789,9 +879,9 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
tool_rules.push_back(builder.add_schema(name + "-call", {
|
||||
{"type", "object"},
|
||||
@@ -839,9 +929,9 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
|
||||
if (!parse_json(it, end, call)) {
|
||||
throw std::runtime_error("Failed to parse json tool call");
|
||||
}
|
||||
const auto & arguments = call["arguments"];
|
||||
const auto & arguments = call.at("arguments");
|
||||
result.tool_calls.push_back({
|
||||
call["name"],
|
||||
call.at("name"),
|
||||
arguments.dump(),
|
||||
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
/* id= */ "",
|
||||
@@ -884,47 +974,72 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
|
||||
}
|
||||
|
||||
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
auto has_tools = !inputs.tools.is_null() && inputs.tool_choice != "none";
|
||||
LOG_DBG("[%s] has_tools=%s\n", __func__, has_tools ? "true" : "false");
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
|
||||
if (has_tools && !inputs.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
if (inputs.tools.is_array()) {
|
||||
if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
}
|
||||
if (caps.supports_tool_calls && !caps.supports_tools) {
|
||||
LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
|
||||
}
|
||||
}
|
||||
|
||||
const auto & src = tmpl.source();
|
||||
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
|
||||
if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) {
|
||||
return common_chat_params_init_deepseek_r1(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Command R7B: : use handler in all cases except json schema (thinking / tools).
|
||||
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
|
||||
return common_chat_params_init_command_r7b(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Use generic handler when mixing tools + JSON schema.
|
||||
// TODO: support that mix in handlers below.
|
||||
if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
|
||||
return common_chat_params_init_generic(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
|
||||
if (src.find(">>>all") != std::string::npos) {
|
||||
// Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
|
||||
return common_chat_params_init_functionary_v3_2(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
|
||||
if (src.find(" functools[") != std::string::npos) {
|
||||
// Firefunction v2 requires datetime and functions in the context, even w/o tools.
|
||||
return common_chat_params_init_firefunction_v2(tmpl, inputs);
|
||||
}
|
||||
|
||||
if (!has_tools) {
|
||||
// Plain handler (no tools)
|
||||
if (inputs.tools.is_null() || inputs.tool_choice == "none") {
|
||||
return common_chat_params_init_without_tools(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
||||
if (src.find("<tool_call>") != std::string::npos) {
|
||||
return common_chat_params_init_hermes_2_pro(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Functionary v3.1 (w/ tools)
|
||||
if (src.find("<|start_header_id|>") != std::string::npos
|
||||
&& src.find("<function=") != std::string::npos) {
|
||||
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Llama 3.1, 3.2, 3.3 (w/ tools)
|
||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
|
||||
}
|
||||
if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
|
||||
return common_chat_params_init_deepseek_r1(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Mistral Nemo (w/ tools)
|
||||
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
||||
return common_chat_params_init_mistral_nemo(tmpl, inputs);
|
||||
}
|
||||
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
|
||||
return common_chat_params_init_command_r7b(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Generic fallback
|
||||
return common_chat_params_init_generic(tmpl, inputs);
|
||||
}
|
||||
|
||||
@@ -949,7 +1064,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
|
||||
return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
|
||||
return common_chat_parse_deepseek_r1(input);
|
||||
return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ false);
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING:
|
||||
return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ true);
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
|
||||
return common_chat_parse_functionary_v3_2(input);
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
|
||||
@@ -959,7 +1076,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
|
||||
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
|
||||
return common_chat_parse_firefunction_v2(input);
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B:
|
||||
return common_chat_parse_command_r7b(input);
|
||||
return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false);
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING:
|
||||
return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true);
|
||||
default:
|
||||
throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ struct common_chat_inputs {
|
||||
bool stream;
|
||||
std::string grammar;
|
||||
bool add_generation_prompt = true;
|
||||
bool extract_reasoning = true;
|
||||
};
|
||||
|
||||
enum common_chat_format {
|
||||
@@ -28,11 +29,13 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
|
||||
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
@@ -140,6 +140,7 @@ struct common_params_sampling {
|
||||
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
||||
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float top_n_sigma = -1.00f;// -1.0 = disabled
|
||||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
bool ignore_eos = false;
|
||||
@@ -202,6 +203,11 @@ struct common_params_vocoder {
|
||||
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
||||
};
|
||||
|
||||
enum common_reasoning_format {
|
||||
COMMON_REASONING_FORMAT_NONE,
|
||||
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
|
||||
};
|
||||
|
||||
struct common_params {
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
int32_t n_ctx = 4096; // context size
|
||||
@@ -292,6 +298,7 @@ struct common_params {
|
||||
bool kl_divergence = false; // compute KL divergence
|
||||
|
||||
bool usage = false; // print usage
|
||||
bool completion = false; // print source-able completion script
|
||||
bool use_color = false; // use color to distinguish generations and inputs
|
||||
bool special = false; // enable special token output
|
||||
bool interactive = false; // interactive mode
|
||||
@@ -346,6 +353,7 @@ struct common_params {
|
||||
std::string chat_template = ""; // NOLINT
|
||||
bool use_jinja = false; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
|
||||
std::vector<std::string> api_keys;
|
||||
|
||||
@@ -424,13 +432,13 @@ bool set_process_priority(enum ggml_sched_priority prio);
|
||||
//
|
||||
|
||||
#ifdef __GNUC__
|
||||
#ifdef __MINGW32__
|
||||
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||
# if defined(__MINGW32__) && !defined(__clang__)
|
||||
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||
# else
|
||||
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||
# endif
|
||||
#else
|
||||
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||
#endif
|
||||
#else
|
||||
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
|
||||
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
|
||||
#endif
|
||||
|
||||
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
|
||||
@@ -623,7 +631,7 @@ struct common_chat_msg {
|
||||
std::string role;
|
||||
std::string content;
|
||||
std::vector<common_tool_call> tool_calls;
|
||||
std::string tool_plan = "";
|
||||
std::string reasoning_content = "";
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "log.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <condition_variable>
|
||||
#include <cstdarg>
|
||||
#include <cstdio>
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "ggml.h" // for ggml_log_level
|
||||
|
||||
#define LOG_CLR_TO_EOL "\033[K\r"
|
||||
#define LOG_COL_DEFAULT "\033[0m"
|
||||
#define LOG_COL_BOLD "\033[1m"
|
||||
#define LOG_COL_RED "\033[31m"
|
||||
@@ -14,7 +15,7 @@
|
||||
|
||||
#ifndef __GNUC__
|
||||
# define LOG_ATTRIBUTE_FORMAT(...)
|
||||
#elif defined(__MINGW32__)
|
||||
#elif defined(__MINGW32__) && !defined(__clang__)
|
||||
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||
#else
|
||||
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||
|
||||
@@ -1385,6 +1385,13 @@ static std::string strip(const std::string & s) {
|
||||
return s.substr(start, end - start + 1);
|
||||
}
|
||||
|
||||
static std::string capitalize(const std::string & s) {
|
||||
if (s.empty()) return s;
|
||||
auto result = s;
|
||||
result[0] = std::toupper(result[0]);
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::string html_escape(const std::string & s) {
|
||||
std::string result;
|
||||
result.reserve(s.size());
|
||||
@@ -1462,6 +1469,9 @@ public:
|
||||
if (method->get_name() == "strip") {
|
||||
vargs.expectArgs("strip method", {0, 0}, {0, 0});
|
||||
return Value(strip(str));
|
||||
} else if (method->get_name() == "capitalize") {
|
||||
vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
|
||||
return Value(capitalize(str));
|
||||
} else if (method->get_name() == "endswith") {
|
||||
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
|
||||
auto suffix = vargs.args[0].get<std::string>();
|
||||
@@ -1792,7 +1802,7 @@ private:
|
||||
auto left = parseStringConcat();
|
||||
if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
|
||||
|
||||
static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)");
|
||||
static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)");
|
||||
static std::regex not_tok(R"(not\b)");
|
||||
std::string op_str;
|
||||
while (!(op_str = consumeToken(compare_tok)).empty()) {
|
||||
@@ -2171,7 +2181,7 @@ private:
|
||||
using TemplateTokenIterator = TemplateTokenVector::const_iterator;
|
||||
|
||||
std::vector<std::string> parseVarNames() {
|
||||
static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)");
|
||||
static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)");
|
||||
|
||||
std::vector<std::string> group;
|
||||
if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names");
|
||||
@@ -2194,13 +2204,13 @@ private:
|
||||
}
|
||||
|
||||
TemplateTokenVector tokenize() {
|
||||
static std::regex comment_tok(R"(\{#([-~]?)([\s\S\r\n]*?)([-~]?)#\})");
|
||||
static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
|
||||
static std::regex expr_open_regex(R"(\{\{([-~])?)");
|
||||
static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
|
||||
static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
|
||||
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
|
||||
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
|
||||
static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
|
||||
static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
|
||||
static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
|
||||
static std::regex block_close_regex(R"(\s*([-~])?%\})");
|
||||
|
||||
TemplateTokenVector tokens;
|
||||
std::vector<std::string> group;
|
||||
@@ -2284,7 +2294,7 @@ private:
|
||||
auto post_space = parseBlockClose();
|
||||
tokens.push_back(std::make_unique<EndGenerationTemplateToken>(location, pre_space, post_space));
|
||||
} else if (keyword == "set") {
|
||||
static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))");
|
||||
static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))");
|
||||
|
||||
std::string ns;
|
||||
std::vector<std::string> var_names;
|
||||
@@ -2336,6 +2346,11 @@ private:
|
||||
throw std::runtime_error("Unexpected block: " + keyword);
|
||||
}
|
||||
} else if (std::regex_search(it, end, match, non_text_open_regex)) {
|
||||
if (!match.position()) {
|
||||
if (match[0] != "{#")
|
||||
throw std::runtime_error("Internal error: Expected a comment");
|
||||
throw std::runtime_error("Missing end of comment tag");
|
||||
}
|
||||
auto text_end = it + match.position();
|
||||
text = std::string(it, text_end);
|
||||
it = text_end;
|
||||
@@ -2400,7 +2415,7 @@ private:
|
||||
|
||||
auto text = text_token->text;
|
||||
if (post_space == SpaceHandling::Strip) {
|
||||
static std::regex trailing_space_regex(R"((\s|\r|\n)+$)");
|
||||
static std::regex trailing_space_regex(R"(\s+$)");
|
||||
text = std::regex_replace(text, trailing_space_regex, "");
|
||||
} else if (options.lstrip_blocks && it != end) {
|
||||
auto i = text.size();
|
||||
@@ -2410,7 +2425,7 @@ private:
|
||||
}
|
||||
}
|
||||
if (pre_space == SpaceHandling::Strip) {
|
||||
static std::regex leading_space_regex(R"(^(\s|\r|\n)+)");
|
||||
static std::regex leading_space_regex(R"(^\s+)");
|
||||
text = std::regex_replace(text, leading_space_regex, "");
|
||||
} else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
|
||||
if (text.length() > 0 && text[0] == '\n') {
|
||||
|
||||
@@ -134,11 +134,11 @@ std::string common_params_sampling::print() const {
|
||||
snprintf(result, sizeof(result),
|
||||
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
||||
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
|
||||
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
|
||||
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
|
||||
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
||||
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
||||
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
|
||||
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
|
||||
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
|
||||
mirostat, mirostat_eta, mirostat_tau);
|
||||
|
||||
return std::string(result);
|
||||
@@ -151,12 +151,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
|
||||
lparams.no_perf = params.no_perf;
|
||||
|
||||
std::vector<const char *> trigger_words;
|
||||
trigger_words.reserve(params.grammar_trigger_words.size());
|
||||
for (const auto & str : params.grammar_trigger_words) {
|
||||
trigger_words.push_back(str.word.c_str());
|
||||
}
|
||||
|
||||
struct llama_sampler * grmr;
|
||||
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
@@ -165,6 +159,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
} else {
|
||||
std::vector<const char *> trigger_words;
|
||||
trigger_words.reserve(params.grammar_trigger_words.size());
|
||||
for (const auto & str : params.grammar_trigger_words) {
|
||||
trigger_words.push_back(str.word.c_str());
|
||||
}
|
||||
|
||||
grmr = params.grammar_lazy
|
||||
? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
|
||||
trigger_words.data(), trigger_words.size(),
|
||||
@@ -188,45 +188,51 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
params.logit_bias.data()));
|
||||
|
||||
if (params.mirostat == 0) {
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
{
|
||||
std::vector<const char *> c_breakers;
|
||||
c_breakers.reserve(params.dry_sequence_breakers.size());
|
||||
for (const auto & str : params.dry_sequence_breakers) {
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
if (params.top_n_sigma >= 0) {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
|
||||
} else {
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
{
|
||||
std::vector<const char *> c_breakers;
|
||||
c_breakers.reserve(params.dry_sequence_breakers.size());
|
||||
for (const auto & str : params.dry_sequence_breakers) {
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
}
|
||||
}
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||
|
||||
@@ -9,7 +9,7 @@ struct common_speculative_params {
|
||||
int n_draft = 16; // max drafted tokens
|
||||
int n_reuse = 256;
|
||||
|
||||
float p_min = 0.9f; // min probabiliy required to accept a token in the draft
|
||||
float p_min = 0.9f; // min probability required to accept a token in the draft
|
||||
};
|
||||
|
||||
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
|
||||
|
||||
205
docs/backend/OPENCL.md
Normal file
205
docs/backend/OPENCL.md
Normal file
@@ -0,0 +1,205 @@
|
||||
# llama.cpp for OpenCL
|
||||
|
||||
- [Background](#background)
|
||||
- [OS](#os)
|
||||
- [Hardware](#hardware)
|
||||
- [DataType Supports](#datatype-supports)
|
||||
- [Model Preparation](#model-preparation)
|
||||
- [CMake Options](#cmake-options)
|
||||
- [Android](#android)
|
||||
- [Windows 11 Arm64](#windows-11-arm64)
|
||||
- [Known Issue](#known-issues)
|
||||
- [TODO](#todo)
|
||||
|
||||
## Background
|
||||
|
||||
OpenCL (Open Computing Language) is an open, royalty-free standard for cross-platform, parallel programming of diverse accelerators found in supercomputers, cloud servers, personal computers, mobile devices and embedded platforms. OpenCL specifies a programming language (based on C99) for programming these devices and application programming interfaces (APIs) to control the platform and execute programs on the compute devices. Similar to CUDA, OpenCL has been widely used to program GPUs and is supported by most GPU vendors.
|
||||
|
||||
### Llama.cpp + OpenCL
|
||||
|
||||
The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adreno GPU** firstly via OpenCL. Thanks to the portabilty of OpenCL, the OpenCL backend can also run on certain Intel GPUs although the performance is not optimal.
|
||||
|
||||
## OS
|
||||
|
||||
| OS | Status | Verified |
|
||||
|---------|---------|------------------------------------------------|
|
||||
| Android | Support | Snapdragon 8 Gen 3, Snapdragon 8 Elite |
|
||||
| Windows | Support | Windows 11 Arm64 with Snapdragon X Elite |
|
||||
| Linux | Support | Ubuntu 22.04 WSL2 with Intel 12700H |
|
||||
|
||||
## Hardware
|
||||
|
||||
### Adreno GPU
|
||||
|
||||
**Verified devices**
|
||||
|
||||
| Adreno GPU | Status |
|
||||
|:------------------------------------:|:-------:|
|
||||
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
|
||||
| Adreno 830 (Snapdragon 8 Elite) | Support |
|
||||
| Adreno X85 (Snapdragon X Elite) | Support |
|
||||
|
||||
## DataType Supports
|
||||
|
||||
| DataType | Status |
|
||||
|:----------------------:|:--------------------------:|
|
||||
| Q4_0 | Support |
|
||||
| Q6_K | Support, but not optimized |
|
||||
|
||||
## Model Preparation
|
||||
|
||||
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model prepration.
|
||||
|
||||
Currently we support `Q4_0` quantization and have optimize for it. To achieve best performance on Adreno GPU, add `--pure` to `llama-quantize`. For example,
|
||||
|
||||
```sh
|
||||
./llama-quantize --pure ggml-model-qwen2.5-3b-f16.gguf ggml-model-qwen-3b-Q4_0.gguf Q4_0
|
||||
```
|
||||
|
||||
Since `Q6_K` is also supported, `Q4_0` quantization without `--pure` will also work. However, the performance will be worse compared to pure `Q4_0` quantization.
|
||||
|
||||
## CMake Options
|
||||
|
||||
The OpenCL backend has the following CMake options that control the behavior of the backend.
|
||||
|
||||
| CMake options | Default value | Description |
|
||||
|:---------------------------------:|:--------------:|:------------------------------------------|
|
||||
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
|
||||
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
|
||||
|
||||
## Android
|
||||
|
||||
Ubuntu 22.04 is used for targeting Android. Make sure the following tools are accessible from command line,
|
||||
|
||||
* Git
|
||||
* CMake 3.29
|
||||
* Ninja
|
||||
* Python3
|
||||
|
||||
### I. Setup Environment
|
||||
|
||||
1. **Install NDK**
|
||||
|
||||
```sh
|
||||
cd ~
|
||||
wget https://dl.google.com/android/repository/commandlinetools-linux-8512546_latest.zip && \
|
||||
unzip commandlinetools-linux-8512546_latest.zip && \
|
||||
mkdir -p ~/android-sdk/cmdline-tools && \
|
||||
mv cmdline-tools latest && \
|
||||
mv latest ~/android-sdk/cmdline-tools/ && \
|
||||
rm -rf commandlinetools-linux-8512546_latest.zip
|
||||
|
||||
yes | ~/android-sdk/cmdline-tools/latest/bin/sdkmanager "ndk;26.3.11579264"
|
||||
```
|
||||
|
||||
2. **Install OpenCL Headers and Library**
|
||||
|
||||
```sh
|
||||
mkdir -p ~/dev/llm
|
||||
cd ~/dev/llm
|
||||
|
||||
git clone https://github.com/KhronosGroup/OpenCL-Headers && \
|
||||
cd OpenCL-Headers && \
|
||||
cp -r CL ~/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
|
||||
|
||||
cd ~/dev/llm
|
||||
|
||||
git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && \
|
||||
cd OpenCL-ICD-Loader && \
|
||||
mkdir build_ndk26 && cd build_ndk26 && \
|
||||
cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_TOOLCHAIN_FILE=$HOME/android-sdk/ndk/26.3.11579264/build/cmake/android.toolchain.cmake \
|
||||
-DOPENCL_ICD_LOADER_HEADERS_DIR=$HOME/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include \
|
||||
-DANDROID_ABI=arm64-v8a \
|
||||
-DANDROID_PLATFORM=24 \
|
||||
-DANDROID_STL=c++_shared && \
|
||||
ninja && \
|
||||
cp libOpenCL.so ~/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android
|
||||
```
|
||||
|
||||
### II. Build llama.cpp
|
||||
|
||||
```sh
|
||||
cd ~/dev/llm
|
||||
|
||||
git clone https://github.com/ggerganov/llama.cpp && \
|
||||
cd llama.cpp && \
|
||||
mkdir build-android && cd build-android
|
||||
|
||||
cmake .. -G Ninja \
|
||||
-DCMAKE_TOOLCHAIN_FILE=$HOME/android-sdk/ndk/26.3.11579264/build/cmake/android.toolchain.cmake \
|
||||
-DANDROID_ABI=arm64-v8a \
|
||||
-DANDROID_PLATFORM=android-28 \
|
||||
-DBUILD_SHARED_LIBS=OFF \
|
||||
-DGGML_OPENCL=ON
|
||||
|
||||
ninja
|
||||
```
|
||||
|
||||
## Windows 11 Arm64
|
||||
|
||||
A Snapdragon X Elite device with Windows 11 Arm64 is used. Make sure the following tools are accessible from command line,
|
||||
|
||||
* Git
|
||||
* CMake 3.29
|
||||
* Clang 19
|
||||
* Ninja
|
||||
* Visual Studio 2022
|
||||
|
||||
Powershell is used for the following instructions.
|
||||
|
||||
### I. Setup Environment
|
||||
|
||||
1. **Install OpenCL Headers and Library**
|
||||
|
||||
```powershell
|
||||
mkdir -p ~/dev/llm
|
||||
|
||||
cd ~/dev/llm
|
||||
git clone https://github.com/KhronosGroup/OpenCL-Headers && cd OpenCL-Headers
|
||||
mkdir build && cd build
|
||||
cmake .. -G Ninja `
|
||||
-DBUILD_TESTING=OFF `
|
||||
-DOPENCL_HEADERS_BUILD_TESTING=OFF `
|
||||
-DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF `
|
||||
-DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl"
|
||||
cmake --build . --target install
|
||||
|
||||
cd ~/dev/llm
|
||||
git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && cd OpenCL-ICD-Loader
|
||||
mkdir build && cd build
|
||||
cmake .. -G Ninja `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" `
|
||||
-DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl"
|
||||
cmake --build . --target install
|
||||
```
|
||||
|
||||
### II. Build llama.cpp
|
||||
|
||||
```powershell
|
||||
|
||||
mkdir -p ~/dev/llm
|
||||
cd ~/dev/llm
|
||||
|
||||
git clone https://github.com/ggerganov/llama.cpp && cd llama.cpp
|
||||
mkdir build && cd build
|
||||
|
||||
cmake .. -G Ninja `
|
||||
-DCMAKE_TOOLCHAIN_FILE="$HOME/dev/llm/llama.cpp/cmake/arm64-windows-llvm.cmake" `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" `
|
||||
-DBUILD_SHARED_LIBS=OFF `
|
||||
-DGGML_OPENCL=ON
|
||||
ninja
|
||||
```
|
||||
|
||||
## Known Issues
|
||||
|
||||
- Qwen2.5 0.5B model produces gibberish output with Adreno kernels.
|
||||
|
||||
## TODO
|
||||
|
||||
- Fix Qwen2.5 0.5B
|
||||
- Optimization for Q6_K
|
||||
- Support and optimization for Q4_K
|
||||
@@ -46,7 +46,7 @@ cmake --build build --config Release
|
||||
```
|
||||
|
||||
- Building for Windows (x86, x64 and arm64) with MSVC or clang as compilers:
|
||||
- Install Visual Studio 2022, e.g. via the [Community Edition](https://visualstudio.microsoft.com/de/vs/community/). In the installer, select at least the following options (this also automatically installs the required additional tools like CMake,...):
|
||||
- Install Visual Studio 2022, e.g. via the [Community Edition](https://visualstudio.microsoft.com/vs/community/). In the installer, select at least the following options (this also automatically installs the required additional tools like CMake,...):
|
||||
- Tab Workload: Desktop-development with C++
|
||||
- Tab Components (select quickly via search): C++-_CMake_ Tools for Windows, _Git_ for Windows, C++-_Clang_ Compiler for Windows, MS-Build Support for LLVM-Toolset (clang)
|
||||
- Please remember to always use a Developer Command Prompt / PowerShell for VS2022 for git, build, test
|
||||
|
||||
@@ -69,7 +69,7 @@ You may want to pass in some different `ARGS`, depending on the CUDA environment
|
||||
|
||||
The defaults are:
|
||||
|
||||
- `CUDA_VERSION` set to `12.6.0`
|
||||
- `CUDA_VERSION` set to `12.4.0`
|
||||
- `CUDA_DOCKER_ARCH` set to the cmake build default, which includes all the supported architectures
|
||||
|
||||
The resulting images, are essentially the same as the non-CUDA images:
|
||||
@@ -104,7 +104,7 @@ You may want to pass in some different `ARGS`, depending on the MUSA environment
|
||||
|
||||
The defaults are:
|
||||
|
||||
- `MUSA_VERSION` set to `rc3.1.0`
|
||||
- `MUSA_VERSION` set to `rc3.1.1`
|
||||
|
||||
The resulting images, are essentially the same as the non-MUSA images:
|
||||
|
||||
|
||||
@@ -13,13 +13,15 @@ cmake -B build -DLLAMA_LLGUIDANCE=ON
|
||||
make -C build -j
|
||||
```
|
||||
|
||||
For Windows use `cmake --build build --config Release` instead of `make`.
|
||||
|
||||
This requires the Rust compiler and the `cargo` tool to be [installed](https://www.rust-lang.org/tools/install).
|
||||
|
||||
## Interface
|
||||
|
||||
There are no new command-line arguments or modifications to `common_params`. When enabled, grammars starting with `%llguidance` are passed to LLGuidance instead of the [current](../grammars/README.md) llama.cpp grammars. Additionally, JSON Schema requests (e.g., using the `-j` argument in `llama-cli`) are also passed to LLGuidance.
|
||||
|
||||
For your existing GBNF grammars, you can use [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/scripts/gbnf_to_lark.py) to convert them to LLGuidance Lark-like format.
|
||||
For your existing GBNF grammars, you can use [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/python/llguidance/gbnf_to_lark.py) to convert them to LLGuidance Lark-like format.
|
||||
|
||||
## Performance
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
|
||||
@@ -876,8 +876,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
struct test {
|
||||
static const std::string build_commit;
|
||||
static const int build_number;
|
||||
static const std::string cpu_info;
|
||||
static const std::string gpu_info;
|
||||
const std::string cpu_info;
|
||||
const std::string gpu_info;
|
||||
std::string model_filename;
|
||||
std::string model_type;
|
||||
uint64_t model_size;
|
||||
@@ -903,7 +903,10 @@ struct test {
|
||||
std::string test_time;
|
||||
std::vector<uint64_t> samples_ns;
|
||||
|
||||
test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) {
|
||||
test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) :
|
||||
cpu_info(get_cpu_info()),
|
||||
gpu_info(get_gpu_info()) {
|
||||
|
||||
model_filename = inst.model;
|
||||
char buf[128];
|
||||
llama_model_desc(lmodel, buf, sizeof(buf));
|
||||
@@ -1058,8 +1061,6 @@ struct test {
|
||||
|
||||
const std::string test::build_commit = LLAMA_COMMIT;
|
||||
const int test::build_number = LLAMA_BUILD_NUMBER;
|
||||
const std::string test::cpu_info = get_cpu_info();
|
||||
const std::string test::gpu_info = get_gpu_info();
|
||||
|
||||
struct printer {
|
||||
virtual ~printer() {}
|
||||
|
||||
@@ -37,7 +37,7 @@ Once downloaded, place your model in the models folder in llama.cpp.
|
||||
|
||||
##### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it):
|
||||
```bash
|
||||
./llama-cli -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
|
||||
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
|
||||
```
|
||||
|
||||
### Windows:
|
||||
@@ -265,6 +265,14 @@ Being experimental and unique, XTC is disabled by default. The recommended combi
|
||||
|
||||
Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1`
|
||||
|
||||
### Top-nσ Sampling
|
||||
|
||||
- `--top-nsigma N`: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1, -1 = disabled).
|
||||
|
||||
Top-nσ sampling is a text generation method that selects tokens based on a statistical threshold in pre-softmax logits. It works by only sampling from tokens with logits that are within n * σ of the maximum logit. This method helps maintain a stable sampling space regardless of temperature scaling, allowing it to perform well on reasoning tasks even in high temperatures. Without complex probability manipulation, it efficiently filters tokens directly on the pre-softmax logits. A higher value for top-nsigma (e.g., 5) will take more noisy tokens into consideration, while a lower value (e.g., 1) will focous on the more informative region of the sampling space.
|
||||
|
||||
Example usage: `--top-nsigma 1`
|
||||
|
||||
### Logit Bias
|
||||
|
||||
- `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion.
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
|
||||
@@ -535,8 +535,7 @@ class HttpClient {
|
||||
|
||||
static void print_progress(const std::string & progress_prefix, const std::string & progress_bar,
|
||||
const std::string & progress_suffix) {
|
||||
printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
|
||||
progress_suffix.c_str());
|
||||
printe("\r" LOG_CLR_TO_EOL "%s%s| %s", progress_prefix.c_str(), progress_bar.c_str(), progress_suffix.c_str());
|
||||
}
|
||||
// Function to write data to a file
|
||||
static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
|
||||
@@ -797,16 +796,13 @@ class LlamaData {
|
||||
llama_model_ptr initialize_model(Opt & opt) {
|
||||
ggml_backend_load_all();
|
||||
resolve_model(opt.model_);
|
||||
printe(
|
||||
"\r%*s"
|
||||
"\rLoading model",
|
||||
get_terminal_width(), " ");
|
||||
printe("\r" LOG_CLR_TO_EOL "Loading model");
|
||||
llama_model_ptr model(llama_model_load_from_file(opt.model_.c_str(), opt.model_params));
|
||||
if (!model) {
|
||||
printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
|
||||
}
|
||||
|
||||
printe("\r%*s\r", static_cast<int>(sizeof("Loading model")), " ");
|
||||
printe("\r" LOG_CLR_TO_EOL);
|
||||
return model;
|
||||
}
|
||||
|
||||
@@ -969,10 +965,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
|
||||
static int read_user_input(std::string & user_input) {
|
||||
static const char * prompt_prefix = "> ";
|
||||
#ifdef WIN32
|
||||
printf(
|
||||
"\r%*s"
|
||||
"\r" LOG_COL_DEFAULT "%s",
|
||||
get_terminal_width(), " ", prompt_prefix);
|
||||
printf("\r" LOG_CLR_TO_EOL LOG_COL_DEFAULT "%s", prompt_prefix);
|
||||
|
||||
std::getline(std::cin, user_input);
|
||||
if (std::cin.eof()) {
|
||||
|
||||
@@ -127,6 +127,7 @@ The project is under active development, and we are [looking for feedback and co
|
||||
| `--grammar-file FNAME` | file to read grammar from |
|
||||
| `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead |
|
||||
| `--jinja` | Enable experimental Jinja templating engine (required for tool use) |
|
||||
| `--reasoning-format FORMAT` | Controls extraction of model thinking traces and the format / field in which they are returned (default: `deepseek`; allowed values: `deepseek`, `none`; requires `--jinja`). `none` will leave thinking traces inline in `message.content` in a model-specific format, while `deepseek` will return them separately under `message.reasoning_content` |
|
||||
|
||||
**Example-specific params**
|
||||
|
||||
@@ -1136,61 +1137,252 @@ curl http://localhost:8080/v1/chat/completions \
|
||||
|
||||
| Template | Format |
|
||||
|----------|--------|
|
||||
| CohereForAI-c4ai-command-r-plus-default.jinja | generic tool calls |
|
||||
| CohereForAI-c4ai-command-r-plus-rag.jinja | generic tool calls |
|
||||
| CohereForAI-c4ai-command-r-plus-tool_use.jinja | generic tool calls |
|
||||
| MiniMaxAI-MiniMax-Text-01.jinja | generic tool calls |
|
||||
| NexaAIDev-Octopus-v2.jinja | generic tool calls |
|
||||
| NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | generic tool calls |
|
||||
| NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | hermes 2 pro tool calls |
|
||||
| NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | generic tool calls |
|
||||
| NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | hermes 2 pro tool calls |
|
||||
| NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | generic tool calls |
|
||||
| NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | hermes 2 pro tool calls |
|
||||
| OrionStarAI-Orion-14B-Chat.jinja | generic tool calls |
|
||||
| Qwen-QwQ-32B-Preview.jinja | hermes 2 pro tool calls |
|
||||
| Qwen-Qwen2-7B-Instruct.jinja | generic tool calls |
|
||||
| Qwen-Qwen2-VL-7B-Instruct.jinja | generic tool calls |
|
||||
| Qwen-Qwen2.5-7B-Instruct.jinja | hermes 2 pro tool calls |
|
||||
| Qwen-Qwen2.5-Math-7B-Instruct.jinja | hermes 2 pro tool calls |
|
||||
| TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | generic tool calls |
|
||||
| abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | generic tool calls |
|
||||
| bofenghuang-vigogne-2-70b-chat.jinja | generic tool calls |
|
||||
| databricks-dbrx-instruct.jinja | generic tool calls |
|
||||
| deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | generic tool calls |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | deepseek r1 tool calls |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | deepseek r1 tool calls |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | deepseek r1 tool calls |
|
||||
| deepseek-ai-DeepSeek-V2.5.jinja | deepseek r1 tool calls |
|
||||
| deepseek-ai-deepseek-coder-33b-instruct.jinja | generic tool calls |
|
||||
| google-gemma-2-2b-it.jinja | generic tool calls |
|
||||
| google-gemma-7b-it.jinja | generic tool calls |
|
||||
| indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | generic tool calls |
|
||||
| mattshumer-Reflection-Llama-3.1-70B.jinja | generic tool calls |
|
||||
| meetkai-functionary-medium-v3.2.jinja | functionary v3.2 tool calls |
|
||||
| meta-llama-Llama-3.1-8B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) |
|
||||
| meta-llama-Llama-3.2-3B-Instruct.jinja | llama 3.x tool calls |
|
||||
| meta-llama-Llama-3.3-70B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) |
|
||||
| meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) |
|
||||
| microsoft-Phi-3-medium-4k-instruct.jinja | generic tool calls |
|
||||
| microsoft-Phi-3-mini-4k-instruct.jinja | generic tool calls |
|
||||
| microsoft-Phi-3-small-8k-instruct.jinja | generic tool calls |
|
||||
| microsoft-Phi-3.5-mini-instruct.jinja | generic tool calls |
|
||||
| microsoft-Phi-3.5-vision-instruct.jinja | generic tool calls |
|
||||
| mistralai-Mistral-7B-Instruct-v0.2.jinja | generic tool calls |
|
||||
| mistralai-Mistral-Large-Instruct-2407.jinja | mistral nemo tool calls |
|
||||
| mistralai-Mistral-Large-Instruct-2411.jinja | generic tool calls |
|
||||
| mistralai-Mistral-Nemo-Instruct-2407.jinja | mistral nemo tool calls |
|
||||
| mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | generic tool calls |
|
||||
| mlabonne-AlphaMonarch-7B.jinja | generic tool calls |
|
||||
| nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | llama 3.x tool calls (w/ builtin tools) |
|
||||
| openchat-openchat-3.5-0106.jinja | generic tool calls |
|
||||
| teknium-OpenHermes-2.5-Mistral-7B.jinja | generic tool calls |
|
||||
| Almawave-Velvet-14B.jinja | Hermes 2 Pro |
|
||||
| AtlaAI-Selene-1-Mini-Llama-3.1-8B.jinja | Llama 3.x |
|
||||
| CohereForAI-aya-expanse-8b.jinja | Generic |
|
||||
| CohereForAI-c4ai-command-r-plus-default.jinja | Generic |
|
||||
| CohereForAI-c4ai-command-r-plus-rag.jinja | Generic |
|
||||
| CohereForAI-c4ai-command-r-plus-tool_use.jinja | Generic |
|
||||
| CohereForAI-c4ai-command-r7b-12-2024-default.jinja | Command R7B (extract reasoning) |
|
||||
| CohereForAI-c4ai-command-r7b-12-2024-rag.jinja | Command R7B (extract reasoning) |
|
||||
| CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja | Command R7B (extract reasoning) |
|
||||
| CohereForAI-c4ai-command-r7b-12-2024.jinja | Generic |
|
||||
| DavieLion-Llama-3.2-1B-SPIN-iter3.jinja | Generic |
|
||||
| Delta-Vector-Rei-12B.jinja | Mistral Nemo |
|
||||
| EpistemeAI-Mistral-Nemo-Instruct-12B-Philosophy-Math.jinja | Mistral Nemo |
|
||||
| FlofloB-83k_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit.jinja | Hermes 2 Pro |
|
||||
| FlofloB-test_continued_pretraining_Phi-3-mini-4k-instruct_Unsloth_merged_16bit.jinja | Generic |
|
||||
| HelpingAI-HAI-SER.jinja | Generic |
|
||||
| HuggingFaceTB-SmolLM2-1.7B-Instruct.jinja | Generic |
|
||||
| HuggingFaceTB-SmolLM2-135M-Instruct.jinja | Generic |
|
||||
| HuggingFaceTB-SmolLM2-360M-Instruct.jinja | Generic |
|
||||
| INSAIT-Institute-BgGPT-Gemma-2-27B-IT-v1.0.jinja | Generic |
|
||||
| Ihor-Text2Graph-R1-Qwen2.5-0.5b.jinja | Hermes 2 Pro |
|
||||
| Infinigence-Megrez-3B-Instruct.jinja | Generic |
|
||||
| Josephgflowers-TinyLlama_v1.1_math_code-world-test-1.jinja | Generic |
|
||||
| LGAI-EXAONE-EXAONE-3.5-2.4B-Instruct.jinja | Generic |
|
||||
| LGAI-EXAONE-EXAONE-3.5-7.8B-Instruct.jinja | Generic |
|
||||
| LatitudeGames-Wayfarer-12B.jinja | Generic |
|
||||
| Magpie-Align-Llama-3-8B-Magpie-Align-v0.1.jinja | Generic |
|
||||
| Magpie-Align-Llama-3.1-8B-Magpie-Align-v0.1.jinja | Generic |
|
||||
| MaziyarPanahi-calme-3.2-instruct-78b.jinja | Generic |
|
||||
| MiniMaxAI-MiniMax-Text-01.jinja | Generic |
|
||||
| MiniMaxAI-MiniMax-VL-01.jinja | Generic |
|
||||
| NaniDAO-deepseek-r1-qwen-2.5-32B-ablated.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| NexaAIDev-Octopus-v2.jinja | Generic |
|
||||
| NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | Generic |
|
||||
| NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | Hermes 2 Pro |
|
||||
| NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | Generic |
|
||||
| NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | Hermes 2 Pro |
|
||||
| NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | Generic |
|
||||
| NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | Hermes 2 Pro |
|
||||
| NovaSky-AI-Sky-T1-32B-Flash.jinja | Hermes 2 Pro |
|
||||
| NovaSky-AI-Sky-T1-32B-Preview.jinja | Hermes 2 Pro |
|
||||
| OnlyCheeini-greesychat-turbo.jinja | Generic |
|
||||
| Orenguteng-Llama-3.1-8B-Lexi-Uncensored-V2.jinja | Llama 3.x |
|
||||
| OrionStarAI-Orion-14B-Chat.jinja | Generic |
|
||||
| PowerInfer-SmallThinker-3B-Preview.jinja | Generic |
|
||||
| PrimeIntellect-INTELLECT-1-Instruct.jinja | Generic |
|
||||
| Qwen-QVQ-72B-Preview.jinja | Generic |
|
||||
| Qwen-QwQ-32B-Preview.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen1.5-7B-Chat.jinja | Generic |
|
||||
| Qwen-Qwen2-7B-Instruct.jinja | Generic |
|
||||
| Qwen-Qwen2-VL-72B-Instruct.jinja | Generic |
|
||||
| Qwen-Qwen2-VL-7B-Instruct.jinja | Generic |
|
||||
| Qwen-Qwen2.5-0.5B.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-1.5B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-14B-Instruct-1M.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-14B.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-32B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-32B.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-3B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-72B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-7B-Instruct-1M.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-7B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-7B.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-Coder-32B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-Coder-7B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-Math-1.5B.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-Math-7B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-VL-3B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-VL-72B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-VL-7B-Instruct.jinja | Hermes 2 Pro |
|
||||
| RWKV-Red-Team-ARWKV-7B-Preview-0.1.jinja | Hermes 2 Pro |
|
||||
| SakanaAI-TinySwallow-1.5B-Instruct.jinja | Hermes 2 Pro |
|
||||
| SakanaAI-TinySwallow-1.5B.jinja | Hermes 2 Pro |
|
||||
| Sao10K-70B-L3.3-Cirrus-x1.jinja | Llama 3.x |
|
||||
| SentientAGI-Dobby-Mini-Leashed-Llama-3.1-8B.jinja | Llama 3.x |
|
||||
| SentientAGI-Dobby-Mini-Unhinged-Llama-3.1-8B.jinja | Llama 3.x |
|
||||
| Steelskull-L3.3-Damascus-R1.jinja | Llama 3.x |
|
||||
| Steelskull-L3.3-MS-Nevoria-70b.jinja | Llama 3.x |
|
||||
| Steelskull-L3.3-Nevoria-R1-70b.jinja | Llama 3.x |
|
||||
| THUDM-glm-4-9b-chat.jinja | Generic |
|
||||
| THUDM-glm-edge-1.5b-chat.jinja | Generic |
|
||||
| Tarek07-Progenitor-V1.1-LLaMa-70B.jinja | Llama 3.x |
|
||||
| TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | Generic |
|
||||
| TinyLlama-TinyLlama-1.1B-Chat-v1.0.jinja | Generic |
|
||||
| UCLA-AGI-Mistral7B-PairRM-SPPO-Iter3.jinja | Generic |
|
||||
| ValiantLabs-Llama3.1-8B-Enigma.jinja | Llama 3.x |
|
||||
| abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | Generic |
|
||||
| ai21labs-AI21-Jamba-1.5-Large.jinja | Generic |
|
||||
| allenai-Llama-3.1-Tulu-3-405B-SFT.jinja | Generic |
|
||||
| allenai-Llama-3.1-Tulu-3-405B.jinja | Generic |
|
||||
| allenai-Llama-3.1-Tulu-3-8B.jinja | Generic |
|
||||
| arcee-ai-Virtuoso-Lite.jinja | Hermes 2 Pro |
|
||||
| arcee-ai-Virtuoso-Medium-v2.jinja | Hermes 2 Pro |
|
||||
| arcee-ai-Virtuoso-Small-v2.jinja | Hermes 2 Pro |
|
||||
| avemio-GRAG-NEMO-12B-ORPO-HESSIAN-AI.jinja | Generic |
|
||||
| bespokelabs-Bespoke-Stratos-7B.jinja | Hermes 2 Pro |
|
||||
| bfuzzy1-acheron-m1a-llama.jinja | Generic |
|
||||
| bofenghuang-vigogne-2-70b-chat.jinja | Generic |
|
||||
| bytedance-research-UI-TARS-72B-DPO.jinja | Generic |
|
||||
| bytedance-research-UI-TARS-7B-DPO.jinja | Generic |
|
||||
| bytedance-research-UI-TARS-7B-SFT.jinja | Generic |
|
||||
| carsenk-phi3.5_mini_exp_825_uncensored.jinja | Generic |
|
||||
| cyberagent-DeepSeek-R1-Distill-Qwen-14B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| cyberagent-DeepSeek-R1-Distill-Qwen-32B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| databricks-dbrx-instruct.jinja | Generic |
|
||||
| deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | Generic |
|
||||
| deepseek-ai-DeepSeek-Coder-V2-Lite-Base.jinja | Generic |
|
||||
| deepseek-ai-DeepSeek-Coder-V2-Lite-Instruct.jinja | Generic |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Llama-70B.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Qwen-1.5B.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Qwen-14B.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-R1-Zero.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-V2-Lite.jinja | Generic |
|
||||
| deepseek-ai-DeepSeek-V2.5.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-V3.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-deepseek-coder-33b-instruct.jinja | Generic |
|
||||
| deepseek-ai-deepseek-coder-6.7b-instruct.jinja | Generic |
|
||||
| deepseek-ai-deepseek-coder-7b-instruct-v1.5.jinja | Generic |
|
||||
| deepseek-ai-deepseek-llm-67b-chat.jinja | Generic |
|
||||
| deepseek-ai-deepseek-llm-7b-chat.jinja | Generic |
|
||||
| dicta-il-dictalm2.0-instruct.jinja | Generic |
|
||||
| ehristoforu-Falcon3-8B-Franken-Basestruct.jinja | Hermes 2 Pro |
|
||||
| fireworks-ai-llama-3-firefunction-v2.jinja | FireFunction v2 |
|
||||
| godlikehhd-alpaca_data_sampled_ifd_new_5200.jinja | Hermes 2 Pro |
|
||||
| godlikehhd-alpaca_data_score_max_0.7_2600.jinja | Hermes 2 Pro |
|
||||
| google-gemma-2-27b-it.jinja | Generic |
|
||||
| google-gemma-2-2b-it.jinja | Generic |
|
||||
| google-gemma-2-2b-jpn-it.jinja | Generic |
|
||||
| google-gemma-7b-it.jinja | Generic |
|
||||
| huihui-ai-DeepSeek-R1-Distill-Llama-70B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| huihui-ai-DeepSeek-R1-Distill-Llama-8B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| huihui-ai-DeepSeek-R1-Distill-Qwen-14B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| huihui-ai-DeepSeek-R1-Distill-Qwen-32B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| huihui-ai-DeepSeek-R1-Distill-Qwen-7B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| huihui-ai-Qwen2.5-14B-Instruct-1M-abliterated.jinja | Hermes 2 Pro |
|
||||
| ibm-granite-granite-3.1-8b-instruct.jinja | Generic |
|
||||
| indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | Generic |
|
||||
| inflatebot-MN-12B-Mag-Mell-R1.jinja | Generic |
|
||||
| jinaai-ReaderLM-v2.jinja | Generic |
|
||||
| kms7530-chemeng_qwen-math-7b_24_1_100_1_nonmath.jinja | Hermes 2 Pro |
|
||||
| knifeayumu-Cydonia-v1.3-Magnum-v4-22B.jinja | Mistral Nemo |
|
||||
| langgptai-qwen1.5-7b-chat-sa-v0.1.jinja | Generic |
|
||||
| lightblue-DeepSeek-R1-Distill-Qwen-7B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| mattshumer-Reflection-Llama-3.1-70B.jinja | Generic |
|
||||
| meetkai-functionary-medium-v3.1.jinja | Functionary v3.1 Llama 3.1 |
|
||||
| meetkai-functionary-medium-v3.2.jinja | Functionary v3.2 |
|
||||
| meta-llama-Llama-2-7b-chat-hf.jinja | Generic |
|
||||
| meta-llama-Llama-3.1-8B-Instruct.jinja | Llama 3.x |
|
||||
| meta-llama-Llama-3.2-11B-Vision-Instruct.jinja | Llama 3.x |
|
||||
| meta-llama-Llama-3.2-1B-Instruct.jinja | Llama 3.x |
|
||||
| meta-llama-Llama-3.2-3B-Instruct.jinja | Llama 3.x |
|
||||
| meta-llama-Llama-3.3-70B-Instruct.jinja | Llama 3.x |
|
||||
| meta-llama-Meta-Llama-3-8B-Instruct.jinja | Generic |
|
||||
| meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | Llama 3.x |
|
||||
| microsoft-Phi-3-medium-4k-instruct.jinja | Generic |
|
||||
| microsoft-Phi-3-mini-4k-instruct.jinja | Generic |
|
||||
| microsoft-Phi-3-small-8k-instruct.jinja | Generic |
|
||||
| microsoft-Phi-3.5-mini-instruct.jinja | Generic |
|
||||
| microsoft-Phi-3.5-vision-instruct.jinja | Generic |
|
||||
| microsoft-phi-4.jinja | Generic |
|
||||
| migtissera-Tess-3-Mistral-Nemo-12B.jinja | Generic |
|
||||
| ministral-Ministral-3b-instruct.jinja | Generic |
|
||||
| mistralai-Codestral-22B-v0.1.jinja | Generic |
|
||||
| mistralai-Mistral-7B-Instruct-v0.1.jinja | Generic |
|
||||
| mistralai-Mistral-7B-Instruct-v0.2.jinja | Generic |
|
||||
| mistralai-Mistral-7B-Instruct-v0.3.jinja | Mistral Nemo |
|
||||
| mistralai-Mistral-Large-Instruct-2407.jinja | Mistral Nemo |
|
||||
| mistralai-Mistral-Large-Instruct-2411.jinja | Generic |
|
||||
| mistralai-Mistral-Nemo-Instruct-2407.jinja | Mistral Nemo |
|
||||
| mistralai-Mistral-Small-24B-Instruct-2501.jinja | Generic |
|
||||
| mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | Generic |
|
||||
| mkurman-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
|
||||
| mlabonne-AlphaMonarch-7B.jinja | Generic |
|
||||
| mlx-community-Josiefied-Qwen2.5-0.5B-Instruct-abliterated-v1-float32.jinja | Hermes 2 Pro |
|
||||
| mlx-community-Qwen2.5-VL-7B-Instruct-8bit.jinja | Hermes 2 Pro |
|
||||
| mobiuslabsgmbh-DeepSeek-R1-ReDistill-Qwen-1.5B-v1.1.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| netcat420-MFANNv0.20.jinja | Generic |
|
||||
| netcat420-MFANNv0.24.jinja | Generic |
|
||||
| netease-youdao-Confucius-o1-14B.jinja | Hermes 2 Pro |
|
||||
| nvidia-AceMath-7B-RM.jinja | Hermes 2 Pro |
|
||||
| nvidia-Eagle2-1B.jinja | Hermes 2 Pro |
|
||||
| nvidia-Eagle2-9B.jinja | Hermes 2 Pro |
|
||||
| nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | Llama 3.x |
|
||||
| onnx-community-DeepSeek-R1-Distill-Qwen-1.5B-ONNX.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| open-thoughts-OpenThinker-7B.jinja | Hermes 2 Pro |
|
||||
| openchat-openchat-3.5-0106.jinja | Generic |
|
||||
| pankajmathur-orca_mini_v6_8b.jinja | Generic |
|
||||
| princeton-nlp-Mistral-7B-Base-SFT-RDPO.jinja | Generic |
|
||||
| princeton-nlp-Mistral-7B-Instruct-DPO.jinja | Generic |
|
||||
| princeton-nlp-Mistral-7B-Instruct-RDPO.jinja | Generic |
|
||||
| prithivMLmods-Bellatrix-Tiny-1.5B-R1.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Bellatrix-Tiny-1B-R1.jinja | Llama 3.x |
|
||||
| prithivMLmods-Bellatrix-Tiny-1B-v3.jinja | Generic |
|
||||
| prithivMLmods-Bellatrix-Tiny-3B-R1.jinja | Llama 3.x |
|
||||
| prithivMLmods-Blaze-14B-xElite.jinja | Generic |
|
||||
| prithivMLmods-Calcium-Opus-14B-Elite2-R1.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Calme-Ties-78B.jinja | Generic |
|
||||
| prithivMLmods-Calme-Ties2-78B.jinja | Generic |
|
||||
| prithivMLmods-Calme-Ties3-78B.jinja | Generic |
|
||||
| prithivMLmods-ChemQwen2-vL.jinja | Generic |
|
||||
| prithivMLmods-GWQ2b.jinja | Generic |
|
||||
| prithivMLmods-LatexMind-2B-Codec.jinja | Generic |
|
||||
| prithivMLmods-Llama-3.2-6B-AlgoCode.jinja | Llama 3.x |
|
||||
| prithivMLmods-Megatron-Opus-14B-Exp.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Megatron-Opus-14B-Stock.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Megatron-Opus-7B-Exp.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Omni-Reasoner-Merged.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Omni-Reasoner4-Merged.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Primal-Opus-14B-Optimus-v1.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-QwQ-Math-IO-500M.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Qwen-7B-Distill-Reasoner.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| prithivMLmods-Qwen2.5-1.5B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Qwen2.5-32B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Qwen2.5-7B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Triangulum-v2-10B.jinja | Hermes 2 Pro |
|
||||
| qingy2024-Falcon3-2x10B-MoE-Instruct.jinja | Hermes 2 Pro |
|
||||
| rubenroy-Zurich-14B-GCv2-5m.jinja | Hermes 2 Pro |
|
||||
| rubenroy-Zurich-7B-GCv2-5m.jinja | Hermes 2 Pro |
|
||||
| silma-ai-SILMA-Kashif-2B-Instruct-v1.0.jinja | Generic |
|
||||
| simplescaling-s1-32B.jinja | Hermes 2 Pro |
|
||||
| sometimesanotion-Lamarck-14B-v0.7.jinja | Hermes 2 Pro |
|
||||
| sonthenguyen-zephyr-sft-bnb-4bit-DPO-mtbr-180steps.jinja | Generic |
|
||||
| sthenno-tempesthenno-icy-0130.jinja | Generic |
|
||||
| sumink-qwft.jinja | Hermes 2 Pro |
|
||||
| teknium-OpenHermes-2.5-Mistral-7B.jinja | Generic |
|
||||
| thirdeyeai-elevate360m.jinja | Generic |
|
||||
| tiiuae-Falcon3-10B-Instruct.jinja | Hermes 2 Pro |
|
||||
| unsloth-DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| unsloth-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| unsloth-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| unsloth-Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit.jinja | Generic |
|
||||
| upstage-solar-pro-preview-instruct.jinja | Generic |
|
||||
| whyhow-ai-PatientSeek.jinja | Generic |
|
||||
| xwen-team-Xwen-72B-Chat.jinja | Hermes 2 Pro |
|
||||
| xwen-team-Xwen-7B-Chat.jinja | Hermes 2 Pro |
|
||||
|
||||
This table can be generated with:
|
||||
|
||||
```bash
|
||||
./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
@@ -1202,11 +1394,20 @@ curl http://localhost:8080/v1/chat/completions \
|
||||
|
||||
```shell
|
||||
# Native support:
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M
|
||||
llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L
|
||||
llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
|
||||
llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
|
||||
|
||||
# Native support for DeepSeek R1 works best w/ our own template (official template buggy)
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L \
|
||||
--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M \
|
||||
--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
|
||||
|
||||
# Native support requires the right template for these GGUFs:
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
|
||||
@@ -1236,17 +1437,17 @@ curl http://localhost:8080/v1/chat/completions \
|
||||
{
|
||||
"type":"function",
|
||||
"function":{
|
||||
"name":"get_current_weather",
|
||||
"description":"Get the current weather in a given location",
|
||||
"name":"python",
|
||||
"description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
|
||||
"parameters":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"location":{
|
||||
"code":{
|
||||
"type":"string",
|
||||
"description":"The city and state, e.g. San Francisco, CA"
|
||||
"description":"The code to run in the ipython interpreter."
|
||||
}
|
||||
},
|
||||
"required":["location"]
|
||||
"required":["code"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1254,7 +1455,7 @@ curl http://localhost:8080/v1/chat/completions \
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Istanbul?."
|
||||
"content": "Print a hello world message with python."
|
||||
}
|
||||
]
|
||||
}'
|
||||
|
||||
Binary file not shown.
@@ -173,6 +173,7 @@ struct slot_params {
|
||||
{"grammar_trigger_words", grammar_trigger_words},
|
||||
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
|
||||
{"preserved_tokens", sampling.preserved_tokens},
|
||||
{"chat_format", common_chat_format_name(oaicompat_chat_format)},
|
||||
{"samplers", samplers},
|
||||
{"speculative.n_max", speculative.n_max},
|
||||
{"speculative.n_min", speculative.n_min},
|
||||
@@ -334,24 +335,24 @@ struct server_task {
|
||||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
||||
try {
|
||||
auto schema = json_value(data, "json_schema", json::object());
|
||||
LOG_DBG("JSON schema: %s\n", schema.dump(2).c_str());
|
||||
SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
|
||||
params.sampling.grammar = json_schema_to_grammar(schema);
|
||||
LOG_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
|
||||
SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
|
||||
} catch (const std::exception & e) {
|
||||
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
||||
}
|
||||
} else {
|
||||
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
||||
LOG_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
|
||||
SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
|
||||
params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
|
||||
LOG_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
|
||||
SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
|
||||
}
|
||||
|
||||
{
|
||||
auto it = data.find("chat_format");
|
||||
if (it != data.end()) {
|
||||
params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
|
||||
LOG_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
|
||||
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
|
||||
} else {
|
||||
params.oaicompat_chat_format = defaults.oaicompat_chat_format;
|
||||
}
|
||||
@@ -367,12 +368,12 @@ struct server_task {
|
||||
|
||||
auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
|
||||
SRV_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
|
||||
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
|
||||
params.sampling.preserved_tokens.insert(ids[0]);
|
||||
continue;
|
||||
}
|
||||
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
|
||||
SRV_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
|
||||
params.sampling.grammar_trigger_words.push_back(trigger);
|
||||
}
|
||||
}
|
||||
@@ -381,11 +382,11 @@ struct server_task {
|
||||
for (const auto & t : *preserved_tokens) {
|
||||
auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
LOG_DBG("Preserved token: %d\n", ids[0]);
|
||||
SRV_DBG("Preserved token: %d\n", ids[0]);
|
||||
params.sampling.preserved_tokens.insert(ids[0]);
|
||||
} else {
|
||||
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
|
||||
LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
|
||||
SRV_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -717,16 +718,26 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
std::string finish_reason = "length";
|
||||
common_chat_msg msg;
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
LOG_DBG("Parsing chat message: %s\n", content.c_str());
|
||||
SRV_DBG("Parsing chat message: %s\n", content.c_str());
|
||||
msg = common_chat_parse(content, oaicompat_chat_format);
|
||||
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
|
||||
} else {
|
||||
msg.content = content;
|
||||
}
|
||||
|
||||
json tool_calls;
|
||||
json message {
|
||||
{"role", "assistant"},
|
||||
};
|
||||
if (!msg.reasoning_content.empty()) {
|
||||
message["reasoning_content"] = msg.reasoning_content;
|
||||
}
|
||||
if (msg.content.empty() && !msg.tool_calls.empty()) {
|
||||
message["content"] = json();
|
||||
} else {
|
||||
message["content"] = msg.content;
|
||||
}
|
||||
if (!msg.tool_calls.empty()) {
|
||||
tool_calls = json::array();
|
||||
auto tool_calls = json::array();
|
||||
for (const auto & tc : msg.tool_calls) {
|
||||
tool_calls.push_back({
|
||||
{"type", "function"},
|
||||
@@ -737,15 +748,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
{"id", tc.id},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
json message {
|
||||
{"content", msg.content},
|
||||
{"tool_calls", tool_calls},
|
||||
{"role", "assistant"},
|
||||
};
|
||||
if (!msg.tool_plan.empty()) {
|
||||
message["tool_plan"] = msg.tool_plan;
|
||||
message["tool_calls"] = tool_calls;
|
||||
}
|
||||
|
||||
json choice {
|
||||
@@ -1600,6 +1603,10 @@ struct server_queue {
|
||||
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (!running) {
|
||||
QUE_DBG("%s", "terminate\n");
|
||||
return;
|
||||
}
|
||||
if (queue_tasks.empty()) {
|
||||
lock.unlock();
|
||||
break;
|
||||
@@ -1620,11 +1627,11 @@ struct server_queue {
|
||||
QUE_DBG("%s", "waiting for new tasks\n");
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (!running) {
|
||||
QUE_DBG("%s", "terminate\n");
|
||||
return;
|
||||
}
|
||||
if (queue_tasks.empty()) {
|
||||
if (!running) {
|
||||
QUE_DBG("%s", "terminate\n");
|
||||
return;
|
||||
}
|
||||
condition_tasks.wait(lock, [&]{
|
||||
return (!queue_tasks.empty() || !running);
|
||||
});
|
||||
@@ -1885,7 +1892,7 @@ struct server_context {
|
||||
}
|
||||
|
||||
if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
|
||||
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
|
||||
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
|
||||
chat_templates = common_chat_templates_from_model(model, "chatml");
|
||||
} else {
|
||||
chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
|
||||
@@ -2069,8 +2076,8 @@ struct server_context {
|
||||
|
||||
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
|
||||
// Might be better to reject the request with a 400 ?
|
||||
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict);
|
||||
slot.params.n_predict = slot.n_predict;
|
||||
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
|
||||
}
|
||||
|
||||
if (slot.params.ignore_eos && has_eos_token) {
|
||||
@@ -2275,7 +2282,7 @@ struct server_context {
|
||||
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
|
||||
result.probs.push_back({
|
||||
cur_p->data[i].id,
|
||||
common_detokenize(ctx, {cur_p->data[i].id}, special),
|
||||
common_token_to_piece(ctx, cur_p->data[i].id, special),
|
||||
cur_p->data[i].p
|
||||
});
|
||||
}
|
||||
@@ -2297,7 +2304,7 @@ struct server_context {
|
||||
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
|
||||
result.probs.push_back({
|
||||
cur[i].id,
|
||||
common_detokenize(ctx, {cur[i].id}, special),
|
||||
common_token_to_piece(ctx, cur[i].id, special),
|
||||
cur[i].p
|
||||
});
|
||||
}
|
||||
@@ -3355,10 +3362,10 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
|
||||
|
||||
// reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
|
||||
|
||||
LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
|
||||
SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
|
||||
|
||||
LOG_DBG("request: %s\n", req.body.c_str());
|
||||
LOG_DBG("response: %s\n", res.body.c_str());
|
||||
SRV_DBG("request: %s\n", req.body.c_str());
|
||||
SRV_DBG("response: %s\n", res.body.c_str());
|
||||
}
|
||||
|
||||
std::function<void(int)> shutdown_handler;
|
||||
@@ -3860,7 +3867,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
try {
|
||||
const auto & prompt = data.at("prompt");
|
||||
LOG_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
||||
// TODO: this log can become very long, put it behind a flag or think about a more compact format
|
||||
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
||||
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
|
||||
tasks.reserve(tokenized_prompts.size());
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
@@ -4054,7 +4063,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
auto body = json::parse(req.body);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
|
||||
|
||||
return handle_completions_impl(
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
@@ -4067,7 +4076,7 @@ int main(int argc, char ** argv) {
|
||||
// same with handle_chat_completions, but without inference part
|
||||
const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
auto body = json::parse(req.body);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
|
||||
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
|
||||
};
|
||||
|
||||
@@ -4376,6 +4385,9 @@ int main(int argc, char ** argv) {
|
||||
res.set_content("Error: gzip is not supported by this browser", "text/plain");
|
||||
} else {
|
||||
res.set_header("Content-Encoding", "gzip");
|
||||
// COEP and COOP headers, required by pyodide (python interpreter)
|
||||
res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
|
||||
res.set_header("Cross-Origin-Opener-Policy", "same-origin");
|
||||
res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
|
||||
}
|
||||
return false;
|
||||
@@ -4425,6 +4437,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// clean up function, to be called before exit
|
||||
auto clean_up = [&svr]() {
|
||||
SRV_INF("%s: cleaning up before exit...\n", __func__);
|
||||
svr->stop();
|
||||
llama_backend_free();
|
||||
};
|
||||
@@ -4441,10 +4454,6 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
if (!was_bound) {
|
||||
//LOG_ERROR("couldn't bind HTTP server socket", {
|
||||
// {"hostname", params.hostname},
|
||||
// {"port", params.port},
|
||||
//});
|
||||
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
|
||||
clean_up();
|
||||
return 1;
|
||||
@@ -4461,7 +4470,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
if (!ctx_server.load_model(params)) {
|
||||
clean_up();
|
||||
t.join();
|
||||
// t.join(); // FIXME: see below
|
||||
LOG_ERR("%s: exiting due to model loading error\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -4485,13 +4494,10 @@ int main(int argc, char ** argv) {
|
||||
});
|
||||
|
||||
shutdown_handler = [&](int) {
|
||||
// this will unblock start_loop()
|
||||
ctx_server.queue_tasks.terminate();
|
||||
};
|
||||
|
||||
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
|
||||
|
||||
ctx_server.queue_tasks.start_loop();
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
@@ -4506,8 +4512,13 @@ int main(int argc, char ** argv) {
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
|
||||
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
|
||||
|
||||
// this call blocks the main thread until queue_tasks.terminate() is called
|
||||
ctx_server.queue_tasks.start_loop();
|
||||
|
||||
clean_up();
|
||||
t.join();
|
||||
// t.join(); // FIXME: http thread may stuck if there is an on-going request. we don't need to care about this for now as the HTTP connection will already be closed at this point, but it's better to fix this
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -92,6 +92,7 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
|
||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
@@ -155,11 +156,11 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
@@ -175,7 +176,7 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
# TODO: fix these
|
||||
# (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
# (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
@@ -214,6 +215,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
|
||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
@@ -273,7 +275,6 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("hf_repo,template_override", [
|
||||
("bartowski/c4ai-command-r7b-12-2024-GGUF:Q4_K_M", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
|
||||
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
@@ -298,13 +299,16 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
|
||||
|
||||
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
|
||||
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
|
||||
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
|
||||
# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
# ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None):
|
||||
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
global server
|
||||
n_predict = 512
|
||||
server.n_slots = 1
|
||||
@@ -323,6 +327,7 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
|
||||
{"role": "user", "content": "What is the weather in Istanbul?"},
|
||||
],
|
||||
"tools": [WEATHER_TOOL],
|
||||
@@ -332,6 +337,7 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None)
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
|
||||
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
|
||||
@@ -340,22 +346,166 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None)
|
||||
assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
|
||||
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
|
||||
(None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
|
||||
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
|
||||
("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
])
|
||||
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
global server
|
||||
# n_predict = 512
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192 * 2
|
||||
server.n_predict = n_predict
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things, and provide very concise answers. Do not explain your reasoning to the user. Provide any numerical values back to the user with at most two decimals."},
|
||||
{"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_6789",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calculate",
|
||||
"arguments": "{\"expression\":\"sin(30 * pi / 180)\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "calculate",
|
||||
"content": 0.55644242476,
|
||||
"tool_call_id": "call_6789"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type":"function",
|
||||
"function":{
|
||||
"name":"calculate",
|
||||
"description":"A calculator function that computes values of arithmetic expressions in the Python syntax",
|
||||
"parameters":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"expression":{
|
||||
"type":"string",
|
||||
"description":"An arithmetic expression to compute the value of (Python syntad, assuming all floats)"
|
||||
}
|
||||
},
|
||||
"required":["expression"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
|
||||
content = choice["message"].get("content")
|
||||
assert content is not None, f'Expected content in {choice["message"]}'
|
||||
if result_override is not None:
|
||||
assert re.match(result_override, content), f'Expected {result_override}, got {content}'
|
||||
else:
|
||||
assert re.match('^[\\s\\S]*?The (y[ -])?coordinate [\\s\\S]*?is (approximately )?0\\.56\\b|^0\\.56$', content), \
|
||||
f'Expected something like "The y coordinate is 0.56.", got {content}'
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
|
||||
(128, 'deepseek', "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
|
||||
(1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(1024, 'none', "<think>\n?I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
|
||||
(1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
])
|
||||
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
global server
|
||||
server.n_slots = 1
|
||||
server.reasoning_format = reasoning_format
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192 * 2
|
||||
server.n_predict = n_predict
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the sum of 102 and 7?"},
|
||||
]
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
|
||||
|
||||
content = choice["message"].get("content")
|
||||
if expect_content is None:
|
||||
assert content is None, f'Expected no content in {choice["message"]}'
|
||||
else:
|
||||
assert re.match(expect_content, content), f'Expected {expect_content}, got {content}'
|
||||
|
||||
reasoning_content = choice["message"].get("reasoning_content")
|
||||
if expect_reasoning_content is None:
|
||||
assert reasoning_content is None, f'Expected no reasoning content in {choice["message"]}'
|
||||
else:
|
||||
assert re.match(expect_reasoning_content, reasoning_content), f'Expected {expect_reasoning_content}, got {reasoning_content}'
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
|
||||
(None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
# (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
|
||||
(None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
|
||||
|
||||
(None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
('{"code":"print("}', "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(None, "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
@@ -371,15 +521,13 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None)
|
||||
|
||||
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
|
||||
(None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
|
||||
# (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
def test_hello_world(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
global server
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = 128
|
||||
server.n_predict = 512 # High because of DeepSeek R1
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
@@ -406,6 +554,7 @@ def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo:
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
|
||||
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
if expected_arguments_override is not None:
|
||||
|
||||
@@ -78,6 +78,7 @@ class ServerProcess:
|
||||
draft_max: int | None = None
|
||||
no_webui: bool | None = None
|
||||
jinja: bool | None = None
|
||||
reasoning_format: Literal['deepseek', 'none'] | None = None
|
||||
chat_template: str | None = None
|
||||
chat_template_file: str | None = None
|
||||
|
||||
@@ -172,6 +173,8 @@ class ServerProcess:
|
||||
server_args.append("--no-webui")
|
||||
if self.jinja:
|
||||
server_args.append("--jinja")
|
||||
if self.reasoning_format is not None:
|
||||
server_args.extend(("--reasoning-format", self.reasoning_format))
|
||||
if self.chat_template:
|
||||
server_args.extend(["--chat-template", self.chat_template])
|
||||
if self.chat_template_file:
|
||||
|
||||
@@ -578,6 +578,7 @@ static json oaicompat_completion_params_parse(const json & body) {
|
||||
static json oaicompat_completion_params_parse(
|
||||
const json & body, /* openai api json semantics */
|
||||
bool use_jinja,
|
||||
common_reasoning_format reasoning_format,
|
||||
const common_chat_templates & chat_templates)
|
||||
{
|
||||
json llama_params;
|
||||
@@ -633,9 +634,10 @@ static json oaicompat_completion_params_parse(
|
||||
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
|
||||
}
|
||||
common_chat_inputs inputs;
|
||||
inputs.messages = body.at("messages");
|
||||
inputs.tools = tools;
|
||||
inputs.tool_choice = tool_choice;
|
||||
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
inputs.messages = body.at("messages");
|
||||
inputs.tools = tools;
|
||||
inputs.tool_choice = tool_choice;
|
||||
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
|
||||
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
|
||||
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
|
||||
|
||||
17
examples/server/webui/package-lock.json
generated
17
examples/server/webui/package-lock.json
generated
@@ -8,10 +8,12 @@
|
||||
"name": "webui",
|
||||
"version": "0.0.0",
|
||||
"dependencies": {
|
||||
"@heroicons/react": "^2.2.0",
|
||||
"@sec-ant/readable-stream": "^0.6.0",
|
||||
"@vscode/markdown-it-katex": "^1.1.1",
|
||||
"autoprefixer": "^10.4.20",
|
||||
"daisyui": "^4.12.14",
|
||||
"dexie": "^4.0.11",
|
||||
"highlight.js": "^11.10.0",
|
||||
"katex": "^0.16.15",
|
||||
"postcss": "^8.4.49",
|
||||
@@ -902,6 +904,15 @@
|
||||
"node": "^18.18.0 || ^20.9.0 || >=21.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@heroicons/react": {
|
||||
"version": "2.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@heroicons/react/-/react-2.2.0.tgz",
|
||||
"integrity": "sha512-LMcepvRaS9LYHJGsF0zzmgKCUim/X3N/DQKc4jepAXJ7l8QxJ1PmxJzqplF2Z3FE4PqBAIGyJAQ/w4B5dsqbtQ==",
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"react": ">= 16 || ^19.0.0-rc"
|
||||
}
|
||||
},
|
||||
"node_modules/@humanfs/core": {
|
||||
"version": "0.19.1",
|
||||
"resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz",
|
||||
@@ -2328,6 +2339,12 @@
|
||||
"url": "https://github.com/sponsors/wooorm"
|
||||
}
|
||||
},
|
||||
"node_modules/dexie": {
|
||||
"version": "4.0.11",
|
||||
"resolved": "https://registry.npmjs.org/dexie/-/dexie-4.0.11.tgz",
|
||||
"integrity": "sha512-SOKO002EqlvBYYKQSew3iymBoN2EQ4BDw/3yprjh7kAfFzjBYkaMNa/pZvcA7HSWlcKSQb9XhPe3wKyQ0x4A8A==",
|
||||
"license": "Apache-2.0"
|
||||
},
|
||||
"node_modules/didyoumean": {
|
||||
"version": "1.2.2",
|
||||
"resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz",
|
||||
|
||||
@@ -11,10 +11,12 @@
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@heroicons/react": "^2.2.0",
|
||||
"@sec-ant/readable-stream": "^0.6.0",
|
||||
"@vscode/markdown-it-katex": "^1.1.1",
|
||||
"autoprefixer": "^10.4.20",
|
||||
"daisyui": "^4.12.14",
|
||||
"dexie": "^4.0.11",
|
||||
"highlight.js": "^11.10.0",
|
||||
"katex": "^0.16.15",
|
||||
"postcss": "^8.4.49",
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { HashRouter, Outlet, Route, Routes } from 'react-router';
|
||||
import Header from './components/Header';
|
||||
import Sidebar from './components/Sidebar';
|
||||
import { AppContextProvider } from './utils/app.context';
|
||||
import { AppContextProvider, useAppContext } from './utils/app.context';
|
||||
import ChatScreen from './components/ChatScreen';
|
||||
import SettingDialog from './components/SettingDialog';
|
||||
|
||||
function App() {
|
||||
return (
|
||||
@@ -22,13 +23,23 @@ function App() {
|
||||
}
|
||||
|
||||
function AppLayout() {
|
||||
const { showSettings, setShowSettings } = useAppContext();
|
||||
return (
|
||||
<>
|
||||
<Sidebar />
|
||||
<div className="chat-screen drawer-content grow flex flex-col h-screen w-screen mx-auto px-4">
|
||||
<div
|
||||
className="drawer-content grow flex flex-col h-screen w-screen mx-auto px-4 overflow-auto"
|
||||
id="main-scroll"
|
||||
>
|
||||
<Header />
|
||||
<Outlet />
|
||||
</div>
|
||||
{
|
||||
<SettingDialog
|
||||
show={showSettings}
|
||||
onClose={() => setShowSettings(false)}
|
||||
/>
|
||||
}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ export const BASE_URL = new URL('.', document.baseURI).href
|
||||
|
||||
export const CONFIG_DEFAULT = {
|
||||
// Note: in order not to introduce breaking changes, please keep the same data type (number, string, etc) if you want to change the default value. Do not use null or undefined for default value.
|
||||
// Do not use nested objects, keep it single level. Prefix the key if you need to group them.
|
||||
apiKey: '',
|
||||
systemMessage: 'You are a helpful assistant.',
|
||||
showTokensPerSecond: false,
|
||||
@@ -36,6 +37,8 @@ export const CONFIG_DEFAULT = {
|
||||
dry_penalty_last_n: -1,
|
||||
max_tokens: -1,
|
||||
custom: '', // custom json-stringified object
|
||||
// experimental features
|
||||
pyIntepreterEnabled: false,
|
||||
};
|
||||
export const CONFIG_INFO: Record<string, string> = {
|
||||
apiKey: 'Set the API Key if you are using --api-key option for the server.',
|
||||
|
||||
195
examples/server/webui/src/components/CanvasPyInterpreter.tsx
Normal file
195
examples/server/webui/src/components/CanvasPyInterpreter.tsx
Normal file
@@ -0,0 +1,195 @@
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import { OpenInNewTab, XCloseButton } from '../utils/common';
|
||||
import { CanvasType } from '../utils/types';
|
||||
import { PlayIcon, StopIcon } from '@heroicons/react/24/outline';
|
||||
import StorageUtils from '../utils/storage';
|
||||
|
||||
const canInterrupt = typeof SharedArrayBuffer === 'function';
|
||||
|
||||
// adapted from https://pyodide.org/en/stable/usage/webworker.html
|
||||
const WORKER_CODE = `
|
||||
importScripts("https://cdn.jsdelivr.net/pyodide/v0.27.2/full/pyodide.js");
|
||||
|
||||
let stdOutAndErr = [];
|
||||
|
||||
let pyodideReadyPromise = loadPyodide({
|
||||
stdout: (data) => stdOutAndErr.push(data),
|
||||
stderr: (data) => stdOutAndErr.push(data),
|
||||
});
|
||||
|
||||
let alreadySetBuff = false;
|
||||
|
||||
self.onmessage = async (event) => {
|
||||
stdOutAndErr = [];
|
||||
|
||||
// make sure loading is done
|
||||
const pyodide = await pyodideReadyPromise;
|
||||
const { id, python, context, interruptBuffer } = event.data;
|
||||
|
||||
if (interruptBuffer && !alreadySetBuff) {
|
||||
pyodide.setInterruptBuffer(interruptBuffer);
|
||||
alreadySetBuff = true;
|
||||
}
|
||||
|
||||
// Now load any packages we need, run the code, and send the result back.
|
||||
await pyodide.loadPackagesFromImports(python);
|
||||
|
||||
// make a Python dictionary with the data from content
|
||||
const dict = pyodide.globals.get("dict");
|
||||
const globals = dict(Object.entries(context));
|
||||
try {
|
||||
self.postMessage({ id, running: true });
|
||||
// Execute the python code in this context
|
||||
const result = pyodide.runPython(python, { globals });
|
||||
self.postMessage({ result, id, stdOutAndErr });
|
||||
} catch (error) {
|
||||
self.postMessage({ error: error.message, id });
|
||||
}
|
||||
interruptBuffer[0] = 0;
|
||||
};
|
||||
`;
|
||||
|
||||
let worker: Worker;
|
||||
const interruptBuffer = canInterrupt
|
||||
? new Uint8Array(new SharedArrayBuffer(1))
|
||||
: null;
|
||||
|
||||
const startWorker = () => {
|
||||
if (!worker) {
|
||||
worker = new Worker(
|
||||
URL.createObjectURL(new Blob([WORKER_CODE], { type: 'text/javascript' }))
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
if (StorageUtils.getConfig().pyIntepreterEnabled) {
|
||||
startWorker();
|
||||
}
|
||||
|
||||
const runCodeInWorker = (
|
||||
pyCode: string,
|
||||
callbackRunning: () => void
|
||||
): {
|
||||
donePromise: Promise<string>;
|
||||
interrupt: () => void;
|
||||
} => {
|
||||
startWorker();
|
||||
const id = Math.random() * 1e8;
|
||||
const context = {};
|
||||
if (interruptBuffer) {
|
||||
interruptBuffer[0] = 0;
|
||||
}
|
||||
|
||||
const donePromise = new Promise<string>((resolve) => {
|
||||
worker.onmessage = (event) => {
|
||||
const { error, stdOutAndErr, running } = event.data;
|
||||
if (id !== event.data.id) return;
|
||||
if (running) {
|
||||
callbackRunning();
|
||||
return;
|
||||
} else if (error) {
|
||||
resolve(error.toString());
|
||||
} else {
|
||||
resolve(stdOutAndErr.join('\n'));
|
||||
}
|
||||
};
|
||||
worker.postMessage({ id, python: pyCode, context, interruptBuffer });
|
||||
});
|
||||
|
||||
const interrupt = () => {
|
||||
console.log('Interrupting...');
|
||||
console.trace();
|
||||
if (interruptBuffer) {
|
||||
interruptBuffer[0] = 2;
|
||||
}
|
||||
};
|
||||
|
||||
return { donePromise, interrupt };
|
||||
};
|
||||
|
||||
export default function CanvasPyInterpreter() {
|
||||
const { canvasData, setCanvasData } = useAppContext();
|
||||
|
||||
const [code, setCode] = useState(canvasData?.content ?? ''); // copy to avoid direct mutation
|
||||
const [running, setRunning] = useState(false);
|
||||
const [output, setOutput] = useState('');
|
||||
const [interruptFn, setInterruptFn] = useState<() => void>();
|
||||
const [showStopBtn, setShowStopBtn] = useState(false);
|
||||
|
||||
const runCode = async (pycode: string) => {
|
||||
interruptFn?.();
|
||||
setRunning(true);
|
||||
setOutput('Loading Pyodide...');
|
||||
const { donePromise, interrupt } = runCodeInWorker(pycode, () => {
|
||||
setOutput('Running...');
|
||||
setShowStopBtn(canInterrupt);
|
||||
});
|
||||
setInterruptFn(() => interrupt);
|
||||
const out = await donePromise;
|
||||
setOutput(out);
|
||||
setRunning(false);
|
||||
setShowStopBtn(false);
|
||||
};
|
||||
|
||||
// run code on mount
|
||||
useEffect(() => {
|
||||
setCode(canvasData?.content ?? '');
|
||||
runCode(canvasData?.content ?? '');
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [canvasData?.content]);
|
||||
|
||||
if (canvasData?.type !== CanvasType.PY_INTERPRETER) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="card bg-base-200 w-full h-full shadow-xl">
|
||||
<div className="card-body">
|
||||
<div className="flex justify-between items-center mb-4">
|
||||
<span className="text-lg font-bold">Python Interpreter</span>
|
||||
<XCloseButton
|
||||
className="bg-base-100"
|
||||
onClick={() => setCanvasData(null)}
|
||||
/>
|
||||
</div>
|
||||
<div className="grid grid-rows-3 gap-4 h-full">
|
||||
<textarea
|
||||
className="textarea textarea-bordered w-full h-full font-mono"
|
||||
value={code}
|
||||
onChange={(e) => setCode(e.target.value)}
|
||||
></textarea>
|
||||
<div className="font-mono flex flex-col row-span-2">
|
||||
<div className="flex items-center mb-2">
|
||||
<button
|
||||
className="btn btn-sm bg-base-100"
|
||||
onClick={() => runCode(code)}
|
||||
disabled={running}
|
||||
>
|
||||
<PlayIcon className="h-6 w-6" /> Run
|
||||
</button>
|
||||
{showStopBtn && (
|
||||
<button
|
||||
className="btn btn-sm bg-base-100 ml-2"
|
||||
onClick={() => interruptFn?.()}
|
||||
>
|
||||
<StopIcon className="h-6 w-6" /> Stop
|
||||
</button>
|
||||
)}
|
||||
<span className="grow text-right text-xs">
|
||||
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/issues/11762">
|
||||
Report a bug
|
||||
</OpenInNewTab>
|
||||
</span>
|
||||
</div>
|
||||
<textarea
|
||||
className="textarea textarea-bordered h-full dark-color"
|
||||
value={output}
|
||||
readOnly
|
||||
></textarea>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -3,6 +3,7 @@ import { useAppContext } from '../utils/app.context';
|
||||
import { Message, PendingMessage } from '../utils/types';
|
||||
import { classNames } from '../utils/misc';
|
||||
import MarkdownDisplay, { CopyButton } from './MarkdownDisplay';
|
||||
import { ChevronLeftIcon, ChevronRightIcon } from '@heroicons/react/24/outline';
|
||||
|
||||
interface SplitMessage {
|
||||
content: PendingMessage['content'];
|
||||
@@ -12,17 +13,24 @@ interface SplitMessage {
|
||||
|
||||
export default function ChatMessage({
|
||||
msg,
|
||||
siblingLeafNodeIds,
|
||||
siblingCurrIdx,
|
||||
id,
|
||||
scrollToBottom,
|
||||
onRegenerateMessage,
|
||||
onEditMessage,
|
||||
onChangeSibling,
|
||||
isPending,
|
||||
}: {
|
||||
msg: Message | PendingMessage;
|
||||
siblingLeafNodeIds: Message['id'][];
|
||||
siblingCurrIdx: number;
|
||||
id?: string;
|
||||
scrollToBottom: (requiresNearBottom: boolean) => void;
|
||||
onRegenerateMessage(msg: Message): void;
|
||||
onEditMessage(msg: Message, content: string): void;
|
||||
onChangeSibling(sibling: Message['id']): void;
|
||||
isPending?: boolean;
|
||||
}) {
|
||||
const { viewingConversation, replaceMessageAndGenerate, config } =
|
||||
useAppContext();
|
||||
const { viewingChat, config } = useAppContext();
|
||||
const [editingContent, setEditingContent] = useState<string | null>(null);
|
||||
const timings = useMemo(
|
||||
() =>
|
||||
@@ -37,6 +45,8 @@ export default function ChatMessage({
|
||||
: null,
|
||||
[msg.timings]
|
||||
);
|
||||
const nextSibling = siblingLeafNodeIds[siblingCurrIdx + 1];
|
||||
const prevSibling = siblingLeafNodeIds[siblingCurrIdx - 1];
|
||||
|
||||
// for reasoning model, we split the message into content and thought
|
||||
// TODO: implement this as remark/rehype plugin in the future
|
||||
@@ -64,13 +74,7 @@ export default function ChatMessage({
|
||||
return { content: actualContent, thought, isThinking };
|
||||
}, [msg]);
|
||||
|
||||
if (!viewingConversation) return null;
|
||||
|
||||
const regenerate = async () => {
|
||||
replaceMessageAndGenerate(viewingConversation.id, msg.id, undefined, () =>
|
||||
scrollToBottom(true)
|
||||
);
|
||||
};
|
||||
if (!viewingChat) return null;
|
||||
|
||||
return (
|
||||
<div className="group" id={id}>
|
||||
@@ -92,7 +96,7 @@ export default function ChatMessage({
|
||||
<>
|
||||
<textarea
|
||||
dir="auto"
|
||||
className="textarea textarea-bordered bg-base-100 text-base-content w-[calc(90vw-8em)] lg:w-96"
|
||||
className="textarea textarea-bordered bg-base-100 text-base-content max-w-2xl w-[calc(90vw-8em)] h-24"
|
||||
value={editingContent}
|
||||
onChange={(e) => setEditingContent(e.target.value)}
|
||||
></textarea>
|
||||
@@ -105,13 +109,12 @@ export default function ChatMessage({
|
||||
</button>
|
||||
<button
|
||||
className="btn mt-2"
|
||||
onClick={() =>
|
||||
replaceMessageAndGenerate(
|
||||
viewingConversation.id,
|
||||
msg.id,
|
||||
editingContent
|
||||
)
|
||||
}
|
||||
onClick={() => {
|
||||
if (msg.content !== null) {
|
||||
setEditingContent(null);
|
||||
onEditMessage(msg as Message, editingContent);
|
||||
}
|
||||
}}
|
||||
>
|
||||
Submit
|
||||
</button>
|
||||
@@ -149,11 +152,17 @@ export default function ChatMessage({
|
||||
)}
|
||||
</summary>
|
||||
<div className="collapse-content">
|
||||
<MarkdownDisplay content={thought} />
|
||||
<MarkdownDisplay
|
||||
content={thought}
|
||||
isGenerating={isPending}
|
||||
/>
|
||||
</div>
|
||||
</details>
|
||||
)}
|
||||
<MarkdownDisplay content={content} />
|
||||
<MarkdownDisplay
|
||||
content={content}
|
||||
isGenerating={isPending}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
@@ -190,10 +199,35 @@ export default function ChatMessage({
|
||||
{msg.content !== null && (
|
||||
<div
|
||||
className={classNames({
|
||||
'mx-4 mt-2 mb-2': true,
|
||||
'text-right': msg.role === 'user',
|
||||
'flex items-center gap-2 mx-4 mt-2 mb-2': true,
|
||||
'flex-row-reverse': msg.role === 'user',
|
||||
})}
|
||||
>
|
||||
{siblingLeafNodeIds && siblingLeafNodeIds.length > 1 && (
|
||||
<div className="flex gap-1 items-center opacity-60 text-sm">
|
||||
<button
|
||||
className={classNames({
|
||||
'btn btn-sm btn-ghost p-1': true,
|
||||
'opacity-20': !prevSibling,
|
||||
})}
|
||||
onClick={() => prevSibling && onChangeSibling(prevSibling)}
|
||||
>
|
||||
<ChevronLeftIcon className="h-4 w-4" />
|
||||
</button>
|
||||
<span>
|
||||
{siblingCurrIdx + 1} / {siblingLeafNodeIds.length}
|
||||
</span>
|
||||
<button
|
||||
className={classNames({
|
||||
'btn btn-sm btn-ghost p-1': true,
|
||||
'opacity-20': !nextSibling,
|
||||
})}
|
||||
onClick={() => nextSibling && onChangeSibling(nextSibling)}
|
||||
>
|
||||
<ChevronRightIcon className="h-4 w-4" />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{/* user message */}
|
||||
{msg.role === 'user' && (
|
||||
<button
|
||||
@@ -210,18 +244,22 @@ export default function ChatMessage({
|
||||
{!isPending && (
|
||||
<button
|
||||
className="badge btn-mini show-on-hover mr-2"
|
||||
onClick={regenerate}
|
||||
onClick={() => {
|
||||
if (msg.content !== null) {
|
||||
onRegenerateMessage(msg as Message);
|
||||
}
|
||||
}}
|
||||
disabled={msg.content === null}
|
||||
>
|
||||
🔄 Regenerate
|
||||
</button>
|
||||
)}
|
||||
<CopyButton
|
||||
className="badge btn-mini show-on-hover mr-2"
|
||||
content={msg.content}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
<CopyButton
|
||||
className="badge btn-mini show-on-hover mr-2"
|
||||
content={msg.content}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -1,123 +1,243 @@
|
||||
import { useEffect, useRef, useState } from 'react';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import StorageUtils from '../utils/storage';
|
||||
import { useNavigate } from 'react-router';
|
||||
import { useEffect, useMemo, useState } from 'react';
|
||||
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
|
||||
import ChatMessage from './ChatMessage';
|
||||
import { PendingMessage } from '../utils/types';
|
||||
import { CanvasType, Message, PendingMessage } from '../utils/types';
|
||||
import { classNames, throttle } from '../utils/misc';
|
||||
import CanvasPyInterpreter from './CanvasPyInterpreter';
|
||||
import StorageUtils from '../utils/storage';
|
||||
|
||||
/**
|
||||
* A message display is a message node with additional information for rendering.
|
||||
* For example, siblings of the message node are stored as their last node (aka leaf node).
|
||||
*/
|
||||
export interface MessageDisplay {
|
||||
msg: Message | PendingMessage;
|
||||
siblingLeafNodeIds: Message['id'][];
|
||||
siblingCurrIdx: number;
|
||||
isPending?: boolean;
|
||||
}
|
||||
|
||||
function getListMessageDisplay(
|
||||
msgs: Readonly<Message[]>,
|
||||
leafNodeId: Message['id']
|
||||
): MessageDisplay[] {
|
||||
const currNodes = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true);
|
||||
const res: MessageDisplay[] = [];
|
||||
const nodeMap = new Map<Message['id'], Message>();
|
||||
for (const msg of msgs) {
|
||||
nodeMap.set(msg.id, msg);
|
||||
}
|
||||
// find leaf node from a message node
|
||||
const findLeafNode = (msgId: Message['id']): Message['id'] => {
|
||||
let currNode: Message | undefined = nodeMap.get(msgId);
|
||||
while (currNode) {
|
||||
if (currNode.children.length === 0) break;
|
||||
currNode = nodeMap.get(currNode.children.at(-1) ?? -1);
|
||||
}
|
||||
return currNode?.id ?? -1;
|
||||
};
|
||||
// traverse the current nodes
|
||||
for (const msg of currNodes) {
|
||||
const parentNode = nodeMap.get(msg.parent ?? -1);
|
||||
if (!parentNode) continue;
|
||||
const siblings = parentNode.children;
|
||||
if (msg.type !== 'root') {
|
||||
res.push({
|
||||
msg,
|
||||
siblingLeafNodeIds: siblings.map(findLeafNode),
|
||||
siblingCurrIdx: siblings.indexOf(msg.id),
|
||||
});
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
const scrollToBottom = throttle(
|
||||
(requiresNearBottom: boolean, delay: number = 80) => {
|
||||
const mainScrollElem = document.getElementById('main-scroll');
|
||||
if (!mainScrollElem) return;
|
||||
const spaceToBottom =
|
||||
mainScrollElem.scrollHeight -
|
||||
mainScrollElem.scrollTop -
|
||||
mainScrollElem.clientHeight;
|
||||
if (!requiresNearBottom || spaceToBottom < 50) {
|
||||
setTimeout(
|
||||
() => mainScrollElem.scrollTo({ top: mainScrollElem.scrollHeight }),
|
||||
delay
|
||||
);
|
||||
}
|
||||
},
|
||||
80
|
||||
);
|
||||
|
||||
export default function ChatScreen() {
|
||||
const {
|
||||
viewingConversation,
|
||||
viewingChat,
|
||||
sendMessage,
|
||||
isGenerating,
|
||||
stopGenerating,
|
||||
pendingMessages,
|
||||
canvasData,
|
||||
replaceMessageAndGenerate,
|
||||
} = useAppContext();
|
||||
const [inputMsg, setInputMsg] = useState('');
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const navigate = useNavigate();
|
||||
|
||||
const currConvId = viewingConversation?.id ?? '';
|
||||
const pendingMsg: PendingMessage | undefined = pendingMessages[currConvId];
|
||||
// keep track of leaf node for rendering
|
||||
const [currNodeId, setCurrNodeId] = useState<number>(-1);
|
||||
const messages: MessageDisplay[] = useMemo(() => {
|
||||
if (!viewingChat) return [];
|
||||
else return getListMessageDisplay(viewingChat.messages, currNodeId);
|
||||
}, [currNodeId, viewingChat]);
|
||||
|
||||
const scrollToBottom = (requiresNearBottom: boolean) => {
|
||||
if (!containerRef.current) return;
|
||||
const msgListElem = containerRef.current;
|
||||
const spaceToBottom =
|
||||
msgListElem.scrollHeight -
|
||||
msgListElem.scrollTop -
|
||||
msgListElem.clientHeight;
|
||||
if (!requiresNearBottom || spaceToBottom < 50) {
|
||||
setTimeout(
|
||||
() => msgListElem.scrollTo({ top: msgListElem.scrollHeight }),
|
||||
1
|
||||
);
|
||||
const currConvId = viewingChat?.conv.id ?? null;
|
||||
const pendingMsg: PendingMessage | undefined =
|
||||
pendingMessages[currConvId ?? ''];
|
||||
|
||||
useEffect(() => {
|
||||
// reset to latest node when conversation changes
|
||||
setCurrNodeId(-1);
|
||||
// scroll to bottom when conversation changes
|
||||
scrollToBottom(false, 1);
|
||||
}, [currConvId]);
|
||||
|
||||
const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => {
|
||||
if (currLeafNodeId) {
|
||||
setCurrNodeId(currLeafNodeId);
|
||||
}
|
||||
scrollToBottom(true);
|
||||
};
|
||||
|
||||
// scroll to bottom when conversation changes
|
||||
useEffect(() => {
|
||||
scrollToBottom(false);
|
||||
}, [viewingConversation?.id]);
|
||||
|
||||
const sendNewMessage = async () => {
|
||||
if (inputMsg.trim().length === 0 || isGenerating(currConvId)) return;
|
||||
const convId = viewingConversation?.id ?? StorageUtils.getNewConvId();
|
||||
if (inputMsg.trim().length === 0 || isGenerating(currConvId ?? '')) return;
|
||||
const lastInpMsg = inputMsg;
|
||||
setInputMsg('');
|
||||
if (!viewingConversation) {
|
||||
// if user is creating a new conversation, redirect to the new conversation
|
||||
navigate(`/chat/${convId}`);
|
||||
}
|
||||
scrollToBottom(false);
|
||||
// auto scroll as message is being generated
|
||||
const onChunk = () => scrollToBottom(true);
|
||||
if (!(await sendMessage(convId, inputMsg, onChunk))) {
|
||||
setCurrNodeId(-1);
|
||||
// get the last message node
|
||||
const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
|
||||
if (!(await sendMessage(currConvId, lastMsgNodeId, inputMsg, onChunk))) {
|
||||
// restore the input message if failed
|
||||
setInputMsg(lastInpMsg);
|
||||
}
|
||||
};
|
||||
|
||||
const handleEditMessage = async (msg: Message, content: string) => {
|
||||
if (!viewingChat) return;
|
||||
setCurrNodeId(msg.id);
|
||||
scrollToBottom(false);
|
||||
await replaceMessageAndGenerate(
|
||||
viewingChat.conv.id,
|
||||
msg.parent,
|
||||
content,
|
||||
onChunk
|
||||
);
|
||||
setCurrNodeId(-1);
|
||||
scrollToBottom(false);
|
||||
};
|
||||
|
||||
const handleRegenerateMessage = async (msg: Message) => {
|
||||
if (!viewingChat) return;
|
||||
setCurrNodeId(msg.parent);
|
||||
scrollToBottom(false);
|
||||
await replaceMessageAndGenerate(
|
||||
viewingChat.conv.id,
|
||||
msg.parent,
|
||||
null,
|
||||
onChunk
|
||||
);
|
||||
setCurrNodeId(-1);
|
||||
scrollToBottom(false);
|
||||
};
|
||||
|
||||
const hasCanvas = !!canvasData;
|
||||
|
||||
// due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
|
||||
const pendingMsgDisplay: MessageDisplay[] =
|
||||
pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id
|
||||
? [
|
||||
{
|
||||
msg: pendingMsg,
|
||||
siblingLeafNodeIds: [],
|
||||
siblingCurrIdx: 0,
|
||||
isPending: true,
|
||||
},
|
||||
]
|
||||
: [];
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* chat messages */}
|
||||
<div
|
||||
className={classNames({
|
||||
'grid lg:gap-8 grow transition-[300ms]': true,
|
||||
'grid-cols-[1fr_0fr] lg:grid-cols-[1fr_1fr]': hasCanvas, // adapted for mobile
|
||||
'grid-cols-[1fr_0fr]': !hasCanvas,
|
||||
})}
|
||||
>
|
||||
<div
|
||||
id="messages-list"
|
||||
className="flex flex-col grow overflow-y-auto"
|
||||
ref={containerRef}
|
||||
className={classNames({
|
||||
'flex flex-col w-full max-w-[900px] mx-auto': true,
|
||||
'hidden lg:flex': hasCanvas, // adapted for mobile
|
||||
flex: !hasCanvas,
|
||||
})}
|
||||
>
|
||||
<div className="mt-auto flex justify-center">
|
||||
{/* placeholder to shift the message to the bottom */}
|
||||
{viewingConversation ? '' : 'Send a message to start'}
|
||||
{/* chat messages */}
|
||||
<div id="messages-list" className="grow">
|
||||
<div className="mt-auto flex justify-center">
|
||||
{/* placeholder to shift the message to the bottom */}
|
||||
{viewingChat ? '' : 'Send a message to start'}
|
||||
</div>
|
||||
{[...messages, ...pendingMsgDisplay].map((msg) => (
|
||||
<ChatMessage
|
||||
key={msg.msg.id}
|
||||
msg={msg.msg}
|
||||
siblingLeafNodeIds={msg.siblingLeafNodeIds}
|
||||
siblingCurrIdx={msg.siblingCurrIdx}
|
||||
onRegenerateMessage={handleRegenerateMessage}
|
||||
onEditMessage={handleEditMessage}
|
||||
onChangeSibling={setCurrNodeId}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
{viewingConversation?.messages.map((msg) => (
|
||||
<ChatMessage key={msg.id} msg={msg} scrollToBottom={scrollToBottom} />
|
||||
))}
|
||||
|
||||
{pendingMsg && (
|
||||
<ChatMessage
|
||||
msg={pendingMsg}
|
||||
scrollToBottom={scrollToBottom}
|
||||
isPending
|
||||
id="pending-msg"
|
||||
/>
|
||||
{/* chat input */}
|
||||
<div className="flex flex-row items-center pt-8 pb-6 sticky bottom-0 bg-base-100">
|
||||
<textarea
|
||||
className="textarea textarea-bordered w-full"
|
||||
placeholder="Type a message (Shift+Enter to add a new line)"
|
||||
value={inputMsg}
|
||||
onChange={(e) => setInputMsg(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' && e.shiftKey) return;
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
sendNewMessage();
|
||||
}
|
||||
}}
|
||||
id="msg-input"
|
||||
dir="auto"
|
||||
></textarea>
|
||||
{isGenerating(currConvId ?? '') ? (
|
||||
<button
|
||||
className="btn btn-neutral ml-2"
|
||||
onClick={() => stopGenerating(currConvId ?? '')}
|
||||
>
|
||||
Stop
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
className="btn btn-primary ml-2"
|
||||
onClick={sendNewMessage}
|
||||
disabled={inputMsg.trim().length === 0}
|
||||
>
|
||||
Send
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="w-full sticky top-[7em] h-[calc(100vh-9em)]">
|
||||
{canvasData?.type === CanvasType.PY_INTERPRETER && (
|
||||
<CanvasPyInterpreter />
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* chat input */}
|
||||
<div className="flex flex-row items-center mt-8 mb-6">
|
||||
<textarea
|
||||
className="textarea textarea-bordered w-full"
|
||||
placeholder="Type a message (Shift+Enter to add a new line)"
|
||||
value={inputMsg}
|
||||
onChange={(e) => setInputMsg(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' && e.shiftKey) return;
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
sendNewMessage();
|
||||
}
|
||||
}}
|
||||
id="msg-input"
|
||||
dir="auto"
|
||||
></textarea>
|
||||
{isGenerating(currConvId) ? (
|
||||
<button
|
||||
className="btn btn-neutral ml-2"
|
||||
onClick={() => stopGenerating(currConvId)}
|
||||
>
|
||||
Stop
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
className="btn btn-primary ml-2"
|
||||
onClick={sendNewMessage}
|
||||
disabled={inputMsg.trim().length === 0}
|
||||
>
|
||||
Send
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -5,12 +5,11 @@ import { classNames } from '../utils/misc';
|
||||
import daisyuiThemes from 'daisyui/src/theming/themes';
|
||||
import { THEMES } from '../Config';
|
||||
import { useNavigate } from 'react-router';
|
||||
import SettingDialog from './SettingDialog';
|
||||
|
||||
export default function Header() {
|
||||
const navigate = useNavigate();
|
||||
const [selectedTheme, setSelectedTheme] = useState(StorageUtils.getTheme());
|
||||
const [showSettingDialog, setShowSettingDialog] = useState(false);
|
||||
const { setShowSettings } = useAppContext();
|
||||
|
||||
const setTheme = (theme: string) => {
|
||||
StorageUtils.setTheme(theme);
|
||||
@@ -26,12 +25,12 @@ export default function Header() {
|
||||
);
|
||||
}, [selectedTheme]);
|
||||
|
||||
const { isGenerating, viewingConversation } = useAppContext();
|
||||
const isCurrConvGenerating = isGenerating(viewingConversation?.id ?? '');
|
||||
const { isGenerating, viewingChat } = useAppContext();
|
||||
const isCurrConvGenerating = isGenerating(viewingChat?.conv.id ?? '');
|
||||
|
||||
const removeConversation = () => {
|
||||
if (isCurrConvGenerating || !viewingConversation) return;
|
||||
const convId = viewingConversation.id;
|
||||
if (isCurrConvGenerating || !viewingChat) return;
|
||||
const convId = viewingChat?.conv.id;
|
||||
if (window.confirm('Are you sure to delete this conversation?')) {
|
||||
StorageUtils.remove(convId);
|
||||
navigate('/');
|
||||
@@ -39,9 +38,9 @@ export default function Header() {
|
||||
};
|
||||
|
||||
const downloadConversation = () => {
|
||||
if (isCurrConvGenerating || !viewingConversation) return;
|
||||
const convId = viewingConversation.id;
|
||||
const conversationJson = JSON.stringify(viewingConversation, null, 2);
|
||||
if (isCurrConvGenerating || !viewingChat) return;
|
||||
const convId = viewingChat?.conv.id;
|
||||
const conversationJson = JSON.stringify(viewingChat, null, 2);
|
||||
const blob = new Blob([conversationJson], { type: 'application/json' });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
@@ -54,7 +53,7 @@ export default function Header() {
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-row items-center mt-6 mb-6">
|
||||
<div className="flex flex-row items-center pt-6 pb-6 sticky top-0 z-10 bg-base-100">
|
||||
{/* open sidebar button */}
|
||||
<label htmlFor="toggle-drawer" className="btn btn-ghost lg:hidden">
|
||||
<svg
|
||||
@@ -76,40 +75,43 @@ export default function Header() {
|
||||
|
||||
{/* action buttons (top right) */}
|
||||
<div className="flex items-center">
|
||||
<div v-if="messages.length > 0" className="dropdown dropdown-end">
|
||||
{/* "..." button */}
|
||||
<button
|
||||
tabIndex={0}
|
||||
role="button"
|
||||
className="btn m-1"
|
||||
disabled={isCurrConvGenerating}
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="16"
|
||||
height="16"
|
||||
fill="currentColor"
|
||||
className="bi bi-three-dots-vertical"
|
||||
viewBox="0 0 16 16"
|
||||
{viewingChat && (
|
||||
<div className="dropdown dropdown-end">
|
||||
{/* "..." button */}
|
||||
<button
|
||||
tabIndex={0}
|
||||
role="button"
|
||||
className="btn m-1"
|
||||
disabled={isCurrConvGenerating}
|
||||
>
|
||||
<path d="M9.5 13a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0" />
|
||||
</svg>
|
||||
</button>
|
||||
{/* dropdown menu */}
|
||||
<ul
|
||||
tabIndex={0}
|
||||
className="dropdown-content menu bg-base-100 rounded-box z-[1] w-52 p-2 shadow"
|
||||
>
|
||||
<li onClick={downloadConversation}>
|
||||
<a>Download</a>
|
||||
</li>
|
||||
<li className="text-error" onClick={removeConversation}>
|
||||
<a>Delete</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="16"
|
||||
height="16"
|
||||
fill="currentColor"
|
||||
className="bi bi-three-dots-vertical"
|
||||
viewBox="0 0 16 16"
|
||||
>
|
||||
<path d="M9.5 13a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0" />
|
||||
</svg>
|
||||
</button>
|
||||
{/* dropdown menu */}
|
||||
<ul
|
||||
tabIndex={0}
|
||||
className="dropdown-content menu bg-base-100 rounded-box z-[1] w-52 p-2 shadow"
|
||||
>
|
||||
<li onClick={downloadConversation}>
|
||||
<a>Download</a>
|
||||
</li>
|
||||
<li className="text-error" onClick={removeConversation}>
|
||||
<a>Delete</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="tooltip tooltip-bottom" data-tip="Settings">
|
||||
<button className="btn" onClick={() => setShowSettingDialog(true)}>
|
||||
<button className="btn" onClick={() => setShowSettings(true)}>
|
||||
{/* settings button */}
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
@@ -172,11 +174,6 @@ export default function Header() {
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<SettingDialog
|
||||
show={showSettingDialog}
|
||||
onClose={() => setShowSettingDialog(false)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -9,8 +9,16 @@ import 'katex/dist/katex.min.css';
|
||||
import { classNames, copyStr } from '../utils/misc';
|
||||
import { ElementContent, Root } from 'hast';
|
||||
import { visit } from 'unist-util-visit';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import { CanvasType } from '../utils/types';
|
||||
|
||||
export default function MarkdownDisplay({ content }: { content: string }) {
|
||||
export default function MarkdownDisplay({
|
||||
content,
|
||||
isGenerating,
|
||||
}: {
|
||||
content: string;
|
||||
isGenerating?: boolean;
|
||||
}) {
|
||||
const preprocessedContent = useMemo(
|
||||
() => preprocessLaTeX(content),
|
||||
[content]
|
||||
@@ -21,7 +29,11 @@ export default function MarkdownDisplay({ content }: { content: string }) {
|
||||
rehypePlugins={[rehypeHightlight, rehypeKatex, rehypeCustomCopyButton]}
|
||||
components={{
|
||||
button: (props) => (
|
||||
<CopyCodeButton {...props} origContent={preprocessedContent} />
|
||||
<CodeBlockButtons
|
||||
{...props}
|
||||
isGenerating={isGenerating}
|
||||
origContent={preprocessedContent}
|
||||
/>
|
||||
),
|
||||
// note: do not use "pre", "p" or other basic html elements here, it will cause the node to re-render when the message is being generated (this should be a bug with react-markdown, not sure how to fix it)
|
||||
}}
|
||||
@@ -31,11 +43,12 @@ export default function MarkdownDisplay({ content }: { content: string }) {
|
||||
);
|
||||
}
|
||||
|
||||
const CopyCodeButton: React.ElementType<
|
||||
const CodeBlockButtons: React.ElementType<
|
||||
React.ClassAttributes<HTMLButtonElement> &
|
||||
React.HTMLAttributes<HTMLButtonElement> &
|
||||
ExtraProps & { origContent: string }
|
||||
> = ({ node, origContent }) => {
|
||||
ExtraProps & { origContent: string; isGenerating?: boolean }
|
||||
> = ({ node, origContent, isGenerating }) => {
|
||||
const { config } = useAppContext();
|
||||
const startOffset = node?.position?.start.offset ?? 0;
|
||||
const endOffset = node?.position?.end.offset ?? 0;
|
||||
|
||||
@@ -48,14 +61,33 @@ const CopyCodeButton: React.ElementType<
|
||||
[origContent, startOffset, endOffset]
|
||||
);
|
||||
|
||||
const codeLanguage = useMemo(
|
||||
() =>
|
||||
origContent
|
||||
.substring(startOffset, startOffset + 10)
|
||||
.match(/^```([^\n]+)\n/)?.[1] ?? '',
|
||||
[origContent, startOffset]
|
||||
);
|
||||
|
||||
const canRunCode =
|
||||
!isGenerating &&
|
||||
config.pyIntepreterEnabled &&
|
||||
codeLanguage.startsWith('py');
|
||||
|
||||
return (
|
||||
<div
|
||||
className={classNames({
|
||||
'text-right sticky top-4 mb-2 mr-2 h-0': true,
|
||||
'text-right sticky top-[7em] mb-2 mr-2 h-0': true,
|
||||
'display-none': !node?.position,
|
||||
})}
|
||||
>
|
||||
<CopyButton className="badge btn-mini" content={copiedContent} />
|
||||
{canRunCode && (
|
||||
<RunPyCodeButton
|
||||
className="badge btn-mini ml-2"
|
||||
content={copiedContent}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -82,6 +114,31 @@ export const CopyButton = ({
|
||||
);
|
||||
};
|
||||
|
||||
export const RunPyCodeButton = ({
|
||||
content,
|
||||
className,
|
||||
}: {
|
||||
content: string;
|
||||
className?: string;
|
||||
}) => {
|
||||
const { setCanvasData } = useAppContext();
|
||||
return (
|
||||
<>
|
||||
<button
|
||||
className={className}
|
||||
onClick={() =>
|
||||
setCanvasData({
|
||||
type: CanvasType.PY_INTERPRETER,
|
||||
content,
|
||||
})
|
||||
}
|
||||
>
|
||||
▶️ Run
|
||||
</button>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* This injects the "button" element before each "pre" element.
|
||||
* The actual button will be replaced with a react component in the MarkdownDisplay.
|
||||
@@ -95,9 +152,7 @@ function rehypeCustomCopyButton() {
|
||||
// replace current node
|
||||
preNode.properties.visited = 'true';
|
||||
node.tagName = 'div';
|
||||
node.properties = {
|
||||
className: 'relative my-4',
|
||||
};
|
||||
node.properties = {};
|
||||
// add node for button
|
||||
const btnNode: ElementContent = {
|
||||
type: 'element',
|
||||
|
||||
@@ -3,18 +3,27 @@ import { useAppContext } from '../utils/app.context';
|
||||
import { CONFIG_DEFAULT, CONFIG_INFO } from '../Config';
|
||||
import { isDev } from '../Config';
|
||||
import StorageUtils from '../utils/storage';
|
||||
import { isBoolean, isNumeric, isString } from '../utils/misc';
|
||||
import { classNames, isBoolean, isNumeric, isString } from '../utils/misc';
|
||||
import {
|
||||
BeakerIcon,
|
||||
ChatBubbleOvalLeftEllipsisIcon,
|
||||
Cog6ToothIcon,
|
||||
FunnelIcon,
|
||||
HandRaisedIcon,
|
||||
SquaresPlusIcon,
|
||||
} from '@heroicons/react/24/outline';
|
||||
import { OpenInNewTab } from '../utils/common';
|
||||
|
||||
type SettKey = keyof typeof CONFIG_DEFAULT;
|
||||
|
||||
const COMMON_SAMPLER_KEYS: SettKey[] = [
|
||||
const BASIC_KEYS: SettKey[] = [
|
||||
'temperature',
|
||||
'top_k',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'max_tokens',
|
||||
];
|
||||
const OTHER_SAMPLER_KEYS: SettKey[] = [
|
||||
const SAMPLER_KEYS: SettKey[] = [
|
||||
'dynatemp_range',
|
||||
'dynatemp_exponent',
|
||||
'typical_p',
|
||||
@@ -32,6 +41,223 @@ const PENALTY_KEYS: SettKey[] = [
|
||||
'dry_penalty_last_n',
|
||||
];
|
||||
|
||||
enum SettingInputType {
|
||||
SHORT_INPUT,
|
||||
LONG_INPUT,
|
||||
CHECKBOX,
|
||||
CUSTOM,
|
||||
}
|
||||
|
||||
interface SettingFieldInput {
|
||||
type: Exclude<SettingInputType, SettingInputType.CUSTOM>;
|
||||
label: string | React.ReactElement;
|
||||
help?: string | React.ReactElement;
|
||||
key: SettKey;
|
||||
}
|
||||
|
||||
interface SettingFieldCustom {
|
||||
type: SettingInputType.CUSTOM;
|
||||
key: SettKey;
|
||||
component:
|
||||
| string
|
||||
| React.FC<{
|
||||
value: string | boolean | number;
|
||||
onChange: (value: string) => void;
|
||||
}>;
|
||||
}
|
||||
|
||||
interface SettingSection {
|
||||
title: React.ReactElement;
|
||||
fields: (SettingFieldInput | SettingFieldCustom)[];
|
||||
}
|
||||
|
||||
const ICON_CLASSNAME = 'w-4 h-4 mr-1 inline';
|
||||
|
||||
const SETTING_SECTIONS: SettingSection[] = [
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<Cog6ToothIcon className={ICON_CLASSNAME} />
|
||||
General
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: 'API Key',
|
||||
key: 'apiKey',
|
||||
},
|
||||
{
|
||||
type: SettingInputType.LONG_INPUT,
|
||||
label: 'System Message (will be disabled if left empty)',
|
||||
key: 'systemMessage',
|
||||
},
|
||||
...BASIC_KEYS.map(
|
||||
(key) =>
|
||||
({
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: key,
|
||||
key,
|
||||
}) as SettingFieldInput
|
||||
),
|
||||
],
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<FunnelIcon className={ICON_CLASSNAME} />
|
||||
Samplers
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: 'Samplers queue',
|
||||
key: 'samplers',
|
||||
},
|
||||
...SAMPLER_KEYS.map(
|
||||
(key) =>
|
||||
({
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: key,
|
||||
key,
|
||||
}) as SettingFieldInput
|
||||
),
|
||||
],
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<HandRaisedIcon className={ICON_CLASSNAME} />
|
||||
Penalties
|
||||
</>
|
||||
),
|
||||
fields: PENALTY_KEYS.map((key) => ({
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: key,
|
||||
key,
|
||||
})),
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<ChatBubbleOvalLeftEllipsisIcon className={ICON_CLASSNAME} />
|
||||
Reasoning
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label: 'Expand though process by default for generating message',
|
||||
key: 'showThoughtInProgress',
|
||||
},
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label:
|
||||
'Exclude thought process when sending request to API (Recommended for DeepSeek-R1)',
|
||||
key: 'excludeThoughtOnReq',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<SquaresPlusIcon className={ICON_CLASSNAME} />
|
||||
Advanced
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.CUSTOM,
|
||||
key: 'custom', // dummy key, won't be used
|
||||
component: () => {
|
||||
const debugImportDemoConv = async () => {
|
||||
const res = await fetch('/demo-conversation.json');
|
||||
const demoConv = await res.json();
|
||||
StorageUtils.remove(demoConv.id);
|
||||
for (const msg of demoConv.messages) {
|
||||
StorageUtils.appendMsg(demoConv.id, msg);
|
||||
}
|
||||
};
|
||||
return (
|
||||
<button className="btn" onClick={debugImportDemoConv}>
|
||||
(debug) Import demo conversation
|
||||
</button>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label: 'Show tokens per second',
|
||||
key: 'showTokensPerSecond',
|
||||
},
|
||||
{
|
||||
type: SettingInputType.LONG_INPUT,
|
||||
label: (
|
||||
<>
|
||||
Custom JSON config (For more info, refer to{' '}
|
||||
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md">
|
||||
server documentation
|
||||
</OpenInNewTab>
|
||||
)
|
||||
</>
|
||||
),
|
||||
key: 'custom',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<BeakerIcon className={ICON_CLASSNAME} />
|
||||
Experimental
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.CUSTOM,
|
||||
key: 'custom', // dummy key, won't be used
|
||||
component: () => (
|
||||
<>
|
||||
<p className="mb-8">
|
||||
Experimental features are not guaranteed to work correctly.
|
||||
<br />
|
||||
<br />
|
||||
If you encounter any problems, create a{' '}
|
||||
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/issues/new?template=019-bug-misc.yml">
|
||||
Bug (misc.)
|
||||
</OpenInNewTab>{' '}
|
||||
report on Github. Please also specify <b>webui/experimental</b> on
|
||||
the report title and include screenshots.
|
||||
<br />
|
||||
<br />
|
||||
Some features may require packages downloaded from CDN, so they
|
||||
need internet connection.
|
||||
</p>
|
||||
</>
|
||||
),
|
||||
},
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label: (
|
||||
<>
|
||||
<b>Enable Python interpreter</b>
|
||||
<br />
|
||||
<small className="text-xs">
|
||||
This feature uses{' '}
|
||||
<OpenInNewTab href="https://pyodide.org">pyodide</OpenInNewTab>,
|
||||
downloaded from CDN. To use this feature, ask the LLM to generate
|
||||
python code inside a markdown code block. You will see a "Run"
|
||||
button on the code block, near the "Copy" button.
|
||||
</small>
|
||||
</>
|
||||
),
|
||||
key: 'pyIntepreterEnabled',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
export default function SettingDialog({
|
||||
show,
|
||||
onClose,
|
||||
@@ -40,6 +266,7 @@ export default function SettingDialog({
|
||||
onClose: () => void;
|
||||
}) {
|
||||
const { config, saveConfig } = useAppContext();
|
||||
const [sectionIdx, setSectionIdx] = useState(0);
|
||||
|
||||
// clone the config object to prevent direct mutation
|
||||
const [localConfig, setLocalConfig] = useState<typeof CONFIG_DEFAULT>(
|
||||
@@ -92,181 +319,109 @@ export default function SettingDialog({
|
||||
onClose();
|
||||
};
|
||||
|
||||
const debugImportDemoConv = async () => {
|
||||
const res = await fetch('/demo-conversation.json');
|
||||
const demoConv = await res.json();
|
||||
StorageUtils.remove(demoConv.id);
|
||||
for (const msg of demoConv.messages) {
|
||||
StorageUtils.appendMsg(demoConv.id, msg);
|
||||
}
|
||||
onClose();
|
||||
};
|
||||
|
||||
const onChange = (key: SettKey) => (value: string | boolean) => {
|
||||
// note: we do not perform validation here, because we may get incomplete value as user is still typing it
|
||||
setLocalConfig({ ...localConfig, [key]: value });
|
||||
};
|
||||
|
||||
return (
|
||||
<dialog className={`modal ${show ? 'modal-open' : ''}`}>
|
||||
<div className="modal-box">
|
||||
<dialog className={classNames({ modal: true, 'modal-open': show })}>
|
||||
<div className="modal-box w-11/12 max-w-3xl">
|
||||
<h3 className="text-lg font-bold mb-6">Settings</h3>
|
||||
<div className="h-[calc(90vh-12rem)] overflow-y-auto">
|
||||
<p className="opacity-40 mb-6">
|
||||
Settings below are saved in browser's localStorage
|
||||
</p>
|
||||
|
||||
<SettingsModalShortInput
|
||||
configKey="apiKey"
|
||||
configDefault={CONFIG_DEFAULT}
|
||||
value={localConfig.apiKey}
|
||||
onChange={onChange('apiKey')}
|
||||
/>
|
||||
|
||||
<label className="form-control mb-2">
|
||||
<div className="label">
|
||||
System Message (will be disabled if left empty)
|
||||
</div>
|
||||
<textarea
|
||||
className="textarea textarea-bordered h-24"
|
||||
placeholder={`Default: ${CONFIG_DEFAULT.systemMessage}`}
|
||||
value={localConfig.systemMessage}
|
||||
onChange={(e) => onChange('systemMessage')(e.target.value)}
|
||||
/>
|
||||
</label>
|
||||
|
||||
{COMMON_SAMPLER_KEYS.map((key) => (
|
||||
<SettingsModalShortInput
|
||||
key={key}
|
||||
configKey={key}
|
||||
configDefault={CONFIG_DEFAULT}
|
||||
value={localConfig[key]}
|
||||
onChange={onChange(key)}
|
||||
/>
|
||||
))}
|
||||
|
||||
<details className="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
|
||||
<summary className="collapse-title font-bold">
|
||||
Other sampler settings
|
||||
</summary>
|
||||
<div className="collapse-content">
|
||||
<SettingsModalShortInput
|
||||
label="Samplers queue"
|
||||
configKey="samplers"
|
||||
configDefault={CONFIG_DEFAULT}
|
||||
value={localConfig.samplers}
|
||||
onChange={onChange('samplers')}
|
||||
/>
|
||||
{OTHER_SAMPLER_KEYS.map((key) => (
|
||||
<SettingsModalShortInput
|
||||
key={key}
|
||||
configKey={key}
|
||||
configDefault={CONFIG_DEFAULT}
|
||||
value={localConfig[key]}
|
||||
onChange={onChange(key)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details className="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
|
||||
<summary className="collapse-title font-bold">
|
||||
Penalties settings
|
||||
</summary>
|
||||
<div className="collapse-content">
|
||||
{PENALTY_KEYS.map((key) => (
|
||||
<SettingsModalShortInput
|
||||
key={key}
|
||||
configKey={key}
|
||||
configDefault={CONFIG_DEFAULT}
|
||||
value={localConfig[key]}
|
||||
onChange={onChange(key)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details className="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
|
||||
<summary className="collapse-title font-bold">
|
||||
Reasoning models
|
||||
</summary>
|
||||
<div className="collapse-content">
|
||||
<div className="flex flex-row items-center mb-2">
|
||||
<input
|
||||
type="checkbox"
|
||||
className="checkbox"
|
||||
checked={localConfig.showThoughtInProgress}
|
||||
onChange={(e) =>
|
||||
onChange('showThoughtInProgress')(e.target.checked)
|
||||
}
|
||||
/>
|
||||
<span className="ml-4">
|
||||
Expand though process by default for generating message
|
||||
</span>
|
||||
<div className="flex flex-col md:flex-row h-[calc(90vh-12rem)]">
|
||||
{/* Left panel, showing sections - Desktop version */}
|
||||
<div className="hidden md:flex flex-col items-stretch pr-4 mr-4 border-r-2 border-base-200">
|
||||
{SETTING_SECTIONS.map((section, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className={classNames({
|
||||
'btn btn-ghost justify-start font-normal w-44 mb-1': true,
|
||||
'btn-active': sectionIdx === idx,
|
||||
})}
|
||||
onClick={() => setSectionIdx(idx)}
|
||||
dir="auto"
|
||||
>
|
||||
{section.title}
|
||||
</div>
|
||||
<div className="flex flex-row items-center mb-2">
|
||||
<input
|
||||
type="checkbox"
|
||||
className="checkbox"
|
||||
checked={localConfig.excludeThoughtOnReq}
|
||||
onChange={(e) =>
|
||||
onChange('excludeThoughtOnReq')(e.target.checked)
|
||||
}
|
||||
/>
|
||||
<span className="ml-4">
|
||||
Exclude thought process when sending request to API
|
||||
(Recommended for DeepSeek-R1)
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</details>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<details className="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
|
||||
<summary className="collapse-title font-bold">
|
||||
Advanced config
|
||||
</summary>
|
||||
<div className="collapse-content">
|
||||
{/* this button only shows in dev mode, used to import a demo conversation to test message rendering */}
|
||||
{isDev && (
|
||||
<div className="flex flex-row items-center mb-2">
|
||||
<button className="btn" onClick={debugImportDemoConv}>
|
||||
(debug) Import demo conversation
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
<div className="flex flex-row items-center mb-2">
|
||||
<input
|
||||
type="checkbox"
|
||||
className="checkbox"
|
||||
checked={localConfig.showTokensPerSecond}
|
||||
onChange={(e) =>
|
||||
onChange('showTokensPerSecond')(e.target.checked)
|
||||
}
|
||||
/>
|
||||
<span className="ml-4">Show tokens per second</span>
|
||||
</div>
|
||||
<label className="form-control mb-2">
|
||||
<div className="label inline">
|
||||
Custom JSON config (For more info, refer to{' '}
|
||||
<a
|
||||
className="underline"
|
||||
href="https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
{/* Left panel, showing sections - Mobile version */}
|
||||
<div className="md:hidden flex flex-row gap-2 mb-4">
|
||||
<details className="dropdown">
|
||||
<summary className="btn bt-sm w-full m-1">
|
||||
{SETTING_SECTIONS[sectionIdx].title}
|
||||
</summary>
|
||||
<ul className="menu dropdown-content bg-base-100 rounded-box z-[1] w-52 p-2 shadow">
|
||||
{SETTING_SECTIONS.map((section, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className={classNames({
|
||||
'btn btn-ghost justify-start font-normal': true,
|
||||
'btn-active': sectionIdx === idx,
|
||||
})}
|
||||
onClick={() => setSectionIdx(idx)}
|
||||
dir="auto"
|
||||
>
|
||||
server documentation
|
||||
</a>
|
||||
)
|
||||
</div>
|
||||
<textarea
|
||||
className="textarea textarea-bordered h-24"
|
||||
placeholder='Example: { "mirostat": 1, "min_p": 0.1 }'
|
||||
value={localConfig.custom}
|
||||
onChange={(e) => onChange('custom')(e.target.value)}
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
</details>
|
||||
{section.title}
|
||||
</div>
|
||||
))}
|
||||
</ul>
|
||||
</details>
|
||||
</div>
|
||||
|
||||
{/* Right panel, showing setting fields */}
|
||||
<div className="grow overflow-y-auto px-4">
|
||||
{SETTING_SECTIONS[sectionIdx].fields.map((field, idx) => {
|
||||
const key = `${sectionIdx}-${idx}`;
|
||||
if (field.type === SettingInputType.SHORT_INPUT) {
|
||||
return (
|
||||
<SettingsModalShortInput
|
||||
key={key}
|
||||
configKey={field.key}
|
||||
value={localConfig[field.key]}
|
||||
onChange={onChange(field.key)}
|
||||
label={field.label as string}
|
||||
/>
|
||||
);
|
||||
} else if (field.type === SettingInputType.LONG_INPUT) {
|
||||
return (
|
||||
<SettingsModalLongInput
|
||||
key={key}
|
||||
configKey={field.key}
|
||||
value={localConfig[field.key].toString()}
|
||||
onChange={onChange(field.key)}
|
||||
label={field.label as string}
|
||||
/>
|
||||
);
|
||||
} else if (field.type === SettingInputType.CHECKBOX) {
|
||||
return (
|
||||
<SettingsModalCheckbox
|
||||
key={key}
|
||||
configKey={field.key}
|
||||
value={!!localConfig[field.key]}
|
||||
onChange={onChange(field.key)}
|
||||
label={field.label as string}
|
||||
/>
|
||||
);
|
||||
} else if (field.type === SettingInputType.CUSTOM) {
|
||||
return (
|
||||
<div key={key} className="mb-2">
|
||||
{typeof field.component === 'string'
|
||||
? field.component
|
||||
: field.component({
|
||||
value: localConfig[field.key],
|
||||
onChange: onChange(field.key),
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
})}
|
||||
|
||||
<p className="opacity-40 mb-6 text-sm mt-8">
|
||||
Settings are saved in browser's localStorage
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="modal-action">
|
||||
@@ -285,37 +440,97 @@ export default function SettingDialog({
|
||||
);
|
||||
}
|
||||
|
||||
function SettingsModalShortInput({
|
||||
function SettingsModalLongInput({
|
||||
configKey,
|
||||
configDefault,
|
||||
value,
|
||||
onChange,
|
||||
label,
|
||||
}: {
|
||||
configKey: SettKey;
|
||||
configDefault: typeof CONFIG_DEFAULT;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
value: any;
|
||||
value: string;
|
||||
onChange: (value: string) => void;
|
||||
label?: string;
|
||||
}) {
|
||||
return (
|
||||
<label className="input input-bordered join-item grow flex items-center gap-2 mb-2">
|
||||
<div className="dropdown dropdown-hover">
|
||||
<div tabIndex={0} role="button" className="font-bold">
|
||||
{label || configKey}
|
||||
</div>
|
||||
<div className="dropdown-content menu bg-base-100 rounded-box z-10 w-64 p-2 shadow mt-4">
|
||||
{CONFIG_INFO[configKey] ?? '(no help message available)'}
|
||||
</div>
|
||||
</div>
|
||||
<input
|
||||
type="text"
|
||||
className="grow"
|
||||
placeholder={`Default: ${configDefault[configKey] || 'none'}`}
|
||||
<label className="form-control mb-2">
|
||||
<div className="label inline">{label || configKey}</div>
|
||||
<textarea
|
||||
className="textarea textarea-bordered h-24"
|
||||
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
|
||||
value={value}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
/>
|
||||
</label>
|
||||
);
|
||||
}
|
||||
|
||||
function SettingsModalShortInput({
|
||||
configKey,
|
||||
value,
|
||||
onChange,
|
||||
label,
|
||||
}: {
|
||||
configKey: SettKey;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
value: any;
|
||||
onChange: (value: string) => void;
|
||||
label?: string;
|
||||
}) {
|
||||
const helpMsg = CONFIG_INFO[configKey];
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* on mobile, we simply show the help message here */}
|
||||
{helpMsg && (
|
||||
<div className="block md:hidden mb-1">
|
||||
<b>{label || configKey}</b>
|
||||
<br />
|
||||
<p className="text-xs">{helpMsg}</p>
|
||||
</div>
|
||||
)}
|
||||
<label className="input input-bordered join-item grow flex items-center gap-2 mb-2">
|
||||
<div className="dropdown dropdown-hover">
|
||||
<div tabIndex={0} role="button" className="font-bold hidden md:block">
|
||||
{label || configKey}
|
||||
</div>
|
||||
{helpMsg && (
|
||||
<div className="dropdown-content menu bg-base-100 rounded-box z-10 w-64 p-2 shadow mt-4">
|
||||
{helpMsg}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<input
|
||||
type="text"
|
||||
className="grow"
|
||||
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
|
||||
value={value}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
/>
|
||||
</label>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
function SettingsModalCheckbox({
|
||||
configKey,
|
||||
value,
|
||||
onChange,
|
||||
label,
|
||||
}: {
|
||||
configKey: SettKey;
|
||||
value: boolean;
|
||||
onChange: (value: boolean) => void;
|
||||
label: string;
|
||||
}) {
|
||||
return (
|
||||
<div className="flex flex-row items-center mb-2">
|
||||
<input
|
||||
type="checkbox"
|
||||
className="toggle"
|
||||
checked={value}
|
||||
onChange={(e) => onChange(e.target.checked)}
|
||||
/>
|
||||
<span className="ml-4">{label || configKey}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useEffect, useMemo, useState } from 'react';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { classNames } from '../utils/misc';
|
||||
import { Conversation } from '../utils/types';
|
||||
import StorageUtils from '../utils/storage';
|
||||
@@ -7,16 +7,17 @@ import { useNavigate, useParams } from 'react-router';
|
||||
export default function Sidebar() {
|
||||
const params = useParams();
|
||||
const navigate = useNavigate();
|
||||
const currConv = useMemo(
|
||||
() => StorageUtils.getOneConversation(params.convId ?? ''),
|
||||
[params.convId]
|
||||
);
|
||||
|
||||
const [conversations, setConversations] = useState<Conversation[]>([]);
|
||||
const [currConv, setCurrConv] = useState<Conversation | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const handleConversationChange = () => {
|
||||
setConversations(StorageUtils.getAllConversations());
|
||||
StorageUtils.getOneConversation(params.convId ?? '').then(setCurrConv);
|
||||
}, [params.convId]);
|
||||
|
||||
useEffect(() => {
|
||||
const handleConversationChange = async () => {
|
||||
setConversations(await StorageUtils.getAllConversations());
|
||||
};
|
||||
StorageUtils.onConversationChanged(handleConversationChange);
|
||||
handleConversationChange();
|
||||
@@ -82,11 +83,11 @@ export default function Sidebar() {
|
||||
onClick={() => navigate(`/chat/${conv.id}`)}
|
||||
dir="auto"
|
||||
>
|
||||
<span className="truncate">{conv.messages[0].content}</span>
|
||||
<span className="truncate">{conv.name}</span>
|
||||
</div>
|
||||
))}
|
||||
<div className="text-center text-xs opacity-40 mt-auto mx-4">
|
||||
Conversations are saved to browser's localStorage
|
||||
Conversations are saved to browser's IndexedDB
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -45,6 +45,9 @@
|
||||
/* Highlight.js */
|
||||
[data-color-scheme='light'] {
|
||||
@include meta.load-css('highlight.js/styles/stackoverflow-light');
|
||||
.dark-color {
|
||||
@apply bg-base-content text-base-100;
|
||||
}
|
||||
}
|
||||
[data-color-scheme='dark'] {
|
||||
@include meta.load-css('highlight.js/styles/stackoverflow-dark');
|
||||
@@ -52,6 +55,9 @@
|
||||
[data-color-scheme='auto'] {
|
||||
@media (prefers-color-scheme: light) {
|
||||
@include meta.load-css('highlight.js/styles/stackoverflow-light');
|
||||
.dark-color {
|
||||
@apply bg-base-content text-base-100;
|
||||
}
|
||||
}
|
||||
@media (prefers-color-scheme: dark) {
|
||||
@include meta.load-css('highlight.js/styles/stackoverflow-dark');
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
import React, { createContext, useContext, useEffect, useState } from 'react';
|
||||
import { APIMessage, Conversation, Message, PendingMessage } from './types';
|
||||
import {
|
||||
APIMessage,
|
||||
CanvasData,
|
||||
Conversation,
|
||||
Message,
|
||||
PendingMessage,
|
||||
ViewingChat,
|
||||
} from './types';
|
||||
import StorageUtils from './storage';
|
||||
import {
|
||||
filterThoughtFromMsgs,
|
||||
@@ -7,46 +14,65 @@ import {
|
||||
getSSEStreamAsync,
|
||||
} from './misc';
|
||||
import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config';
|
||||
import { matchPath, useLocation } from 'react-router';
|
||||
import { matchPath, useLocation, useNavigate } from 'react-router';
|
||||
|
||||
interface AppContextValue {
|
||||
viewingConversation: Conversation | null;
|
||||
// conversations and messages
|
||||
viewingChat: ViewingChat | null;
|
||||
pendingMessages: Record<Conversation['id'], PendingMessage>;
|
||||
isGenerating: (convId: string) => boolean;
|
||||
sendMessage: (
|
||||
convId: string,
|
||||
convId: string | null,
|
||||
leafNodeId: Message['id'] | null,
|
||||
content: string,
|
||||
onChunk?: CallbackGeneratedChunk
|
||||
onChunk: CallbackGeneratedChunk
|
||||
) => Promise<boolean>;
|
||||
stopGenerating: (convId: string) => void;
|
||||
replaceMessageAndGenerate: (
|
||||
convId: string,
|
||||
origMsgId: Message['id'],
|
||||
content?: string,
|
||||
onChunk?: CallbackGeneratedChunk
|
||||
parentNodeId: Message['id'], // the parent node of the message to be replaced
|
||||
content: string | null,
|
||||
onChunk: CallbackGeneratedChunk
|
||||
) => Promise<void>;
|
||||
|
||||
// canvas
|
||||
canvasData: CanvasData | null;
|
||||
setCanvasData: (data: CanvasData | null) => void;
|
||||
|
||||
// config
|
||||
config: typeof CONFIG_DEFAULT;
|
||||
saveConfig: (config: typeof CONFIG_DEFAULT) => void;
|
||||
showSettings: boolean;
|
||||
setShowSettings: (show: boolean) => void;
|
||||
}
|
||||
|
||||
// for now, this callback is only used for scrolling to the bottom of the chat
|
||||
type CallbackGeneratedChunk = () => void;
|
||||
// this callback is used for scrolling to the bottom of the chat and switching to the last node
|
||||
export type CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => void;
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const AppContext = createContext<AppContextValue>({} as any);
|
||||
|
||||
const getViewingChat = async (convId: string): Promise<ViewingChat | null> => {
|
||||
const conv = await StorageUtils.getOneConversation(convId);
|
||||
if (!conv) return null;
|
||||
return {
|
||||
conv: conv,
|
||||
// all messages from all branches, not filtered by last node
|
||||
messages: await StorageUtils.getMessages(convId),
|
||||
};
|
||||
};
|
||||
|
||||
export const AppContextProvider = ({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactElement;
|
||||
}) => {
|
||||
const { pathname } = useLocation();
|
||||
const navigate = useNavigate();
|
||||
const params = matchPath('/chat/:convId', pathname);
|
||||
const convId = params?.params?.convId;
|
||||
|
||||
const [viewingConversation, setViewingConversation] =
|
||||
useState<Conversation | null>(null);
|
||||
const [viewingChat, setViewingChat] = useState<ViewingChat | null>(null);
|
||||
const [pendingMessages, setPendingMessages] = useState<
|
||||
Record<Conversation['id'], PendingMessage>
|
||||
>({});
|
||||
@@ -54,14 +80,19 @@ export const AppContextProvider = ({
|
||||
Record<Conversation['id'], AbortController>
|
||||
>({});
|
||||
const [config, setConfig] = useState(StorageUtils.getConfig());
|
||||
const [canvasData, setCanvasData] = useState<CanvasData | null>(null);
|
||||
const [showSettings, setShowSettings] = useState(false);
|
||||
|
||||
// handle change when the convId from URL is changed
|
||||
useEffect(() => {
|
||||
const handleConversationChange = (changedConvId: string) => {
|
||||
// also reset the canvas data
|
||||
setCanvasData(null);
|
||||
const handleConversationChange = async (changedConvId: string) => {
|
||||
if (changedConvId !== convId) return;
|
||||
setViewingConversation(StorageUtils.getOneConversation(convId));
|
||||
setViewingChat(await getViewingChat(changedConvId));
|
||||
};
|
||||
StorageUtils.onConversationChanged(handleConversationChange);
|
||||
setViewingConversation(StorageUtils.getOneConversation(convId ?? ''));
|
||||
getViewingChat(convId ?? '').then(setViewingChat);
|
||||
return () => {
|
||||
StorageUtils.offConversationChanged(handleConversationChange);
|
||||
};
|
||||
@@ -99,23 +130,39 @@ export const AppContextProvider = ({
|
||||
|
||||
const generateMessage = async (
|
||||
convId: string,
|
||||
onChunk?: CallbackGeneratedChunk
|
||||
leafNodeId: Message['id'],
|
||||
onChunk: CallbackGeneratedChunk
|
||||
) => {
|
||||
if (isGenerating(convId)) return;
|
||||
|
||||
const config = StorageUtils.getConfig();
|
||||
const currConversation = StorageUtils.getOneConversation(convId);
|
||||
const currConversation = await StorageUtils.getOneConversation(convId);
|
||||
if (!currConversation) {
|
||||
throw new Error('Current conversation is not found');
|
||||
}
|
||||
|
||||
const currMessages = StorageUtils.filterByLeafNodeId(
|
||||
await StorageUtils.getMessages(convId),
|
||||
leafNodeId,
|
||||
false
|
||||
);
|
||||
const abortController = new AbortController();
|
||||
setAbort(convId, abortController);
|
||||
|
||||
if (!currMessages) {
|
||||
throw new Error('Current messages are not found');
|
||||
}
|
||||
|
||||
const pendingId = Date.now() + 1;
|
||||
let pendingMsg: PendingMessage = {
|
||||
id: Date.now() + 1,
|
||||
id: pendingId,
|
||||
convId,
|
||||
type: 'text',
|
||||
timestamp: pendingId,
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
parent: leafNodeId,
|
||||
children: [],
|
||||
};
|
||||
setPending(convId, pendingMsg);
|
||||
|
||||
@@ -125,7 +172,7 @@ export const AppContextProvider = ({
|
||||
...(config.systemMessage.length === 0
|
||||
? []
|
||||
: [{ role: 'system', content: config.systemMessage } as APIMessage]),
|
||||
...normalizeMsgsForAPI(currConversation?.messages ?? []),
|
||||
...normalizeMsgsForAPI(currMessages),
|
||||
];
|
||||
if (config.excludeThoughtOnReq) {
|
||||
messages = filterThoughtFromMsgs(messages);
|
||||
@@ -186,8 +233,7 @@ export const AppContextProvider = ({
|
||||
const lastContent = pendingMsg.content || '';
|
||||
if (addedContent) {
|
||||
pendingMsg = {
|
||||
id: pendingMsg.id,
|
||||
role: 'assistant',
|
||||
...pendingMsg,
|
||||
content: lastContent + addedContent,
|
||||
};
|
||||
}
|
||||
@@ -202,7 +248,7 @@ export const AppContextProvider = ({
|
||||
};
|
||||
}
|
||||
setPending(convId, pendingMsg);
|
||||
onChunk?.();
|
||||
onChunk(); // don't need to switch node for pending message
|
||||
}
|
||||
} catch (err) {
|
||||
setPending(convId, null);
|
||||
@@ -217,37 +263,53 @@ export const AppContextProvider = ({
|
||||
}
|
||||
}
|
||||
|
||||
if (pendingMsg.content) {
|
||||
StorageUtils.appendMsg(currConversation.id, {
|
||||
id: pendingMsg.id,
|
||||
content: pendingMsg.content,
|
||||
role: pendingMsg.role,
|
||||
timings: pendingMsg.timings,
|
||||
});
|
||||
if (pendingMsg.content !== null) {
|
||||
await StorageUtils.appendMsg(pendingMsg as Message, leafNodeId);
|
||||
}
|
||||
setPending(convId, null);
|
||||
onChunk?.(); // trigger scroll to bottom
|
||||
onChunk(pendingId); // trigger scroll to bottom and switch to the last node
|
||||
};
|
||||
|
||||
const sendMessage = async (
|
||||
convId: string,
|
||||
convId: string | null,
|
||||
leafNodeId: Message['id'] | null,
|
||||
content: string,
|
||||
onChunk?: CallbackGeneratedChunk
|
||||
onChunk: CallbackGeneratedChunk
|
||||
): Promise<boolean> => {
|
||||
if (isGenerating(convId) || content.trim().length === 0) return false;
|
||||
if (isGenerating(convId ?? '') || content.trim().length === 0) return false;
|
||||
|
||||
StorageUtils.appendMsg(convId, {
|
||||
id: Date.now(),
|
||||
role: 'user',
|
||||
content,
|
||||
});
|
||||
if (convId === null || convId.length === 0 || leafNodeId === null) {
|
||||
const conv = await StorageUtils.createConversation(
|
||||
content.substring(0, 256)
|
||||
);
|
||||
convId = conv.id;
|
||||
leafNodeId = conv.currNode;
|
||||
// if user is creating a new conversation, redirect to the new conversation
|
||||
navigate(`/chat/${convId}`);
|
||||
}
|
||||
|
||||
const now = Date.now();
|
||||
const currMsgId = now;
|
||||
StorageUtils.appendMsg(
|
||||
{
|
||||
id: currMsgId,
|
||||
timestamp: now,
|
||||
type: 'text',
|
||||
convId,
|
||||
role: 'user',
|
||||
content,
|
||||
parent: leafNodeId,
|
||||
children: [],
|
||||
},
|
||||
leafNodeId
|
||||
);
|
||||
onChunk(currMsgId);
|
||||
|
||||
try {
|
||||
await generateMessage(convId, onChunk);
|
||||
await generateMessage(convId, currMsgId, onChunk);
|
||||
return true;
|
||||
} catch (_) {
|
||||
// rollback
|
||||
StorageUtils.popMsg(convId);
|
||||
// TODO: rollback
|
||||
}
|
||||
return false;
|
||||
};
|
||||
@@ -260,22 +322,33 @@ export const AppContextProvider = ({
|
||||
// if content is undefined, we remove last assistant message
|
||||
const replaceMessageAndGenerate = async (
|
||||
convId: string,
|
||||
origMsgId: Message['id'],
|
||||
content?: string,
|
||||
onChunk?: CallbackGeneratedChunk
|
||||
parentNodeId: Message['id'], // the parent node of the message to be replaced
|
||||
content: string | null,
|
||||
onChunk: CallbackGeneratedChunk
|
||||
) => {
|
||||
if (isGenerating(convId)) return;
|
||||
|
||||
StorageUtils.filterAndKeepMsgs(convId, (msg) => msg.id < origMsgId);
|
||||
if (content) {
|
||||
StorageUtils.appendMsg(convId, {
|
||||
id: Date.now(),
|
||||
role: 'user',
|
||||
content,
|
||||
});
|
||||
if (content !== null) {
|
||||
const now = Date.now();
|
||||
const currMsgId = now;
|
||||
StorageUtils.appendMsg(
|
||||
{
|
||||
id: currMsgId,
|
||||
timestamp: now,
|
||||
type: 'text',
|
||||
convId,
|
||||
role: 'user',
|
||||
content,
|
||||
parent: parentNodeId,
|
||||
children: [],
|
||||
},
|
||||
parentNodeId
|
||||
);
|
||||
parentNodeId = currMsgId;
|
||||
}
|
||||
onChunk(parentNodeId);
|
||||
|
||||
await generateMessage(convId, onChunk);
|
||||
await generateMessage(convId, parentNodeId, onChunk);
|
||||
};
|
||||
|
||||
const saveConfig = (config: typeof CONFIG_DEFAULT) => {
|
||||
@@ -287,13 +360,17 @@ export const AppContextProvider = ({
|
||||
<AppContext.Provider
|
||||
value={{
|
||||
isGenerating,
|
||||
viewingConversation,
|
||||
viewingChat,
|
||||
pendingMessages,
|
||||
sendMessage,
|
||||
stopGenerating,
|
||||
replaceMessageAndGenerate,
|
||||
canvasData,
|
||||
setCanvasData,
|
||||
config,
|
||||
saveConfig,
|
||||
showSettings,
|
||||
setShowSettings,
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
|
||||
38
examples/server/webui/src/utils/common.tsx
Normal file
38
examples/server/webui/src/utils/common.tsx
Normal file
@@ -0,0 +1,38 @@
|
||||
export const XCloseButton: React.ElementType<
|
||||
React.ClassAttributes<HTMLButtonElement> &
|
||||
React.HTMLAttributes<HTMLButtonElement>
|
||||
> = ({ className, ...props }) => (
|
||||
<button className={`btn btn-square btn-sm ${className ?? ''}`} {...props}>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-6 w-6"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth="2"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
);
|
||||
|
||||
export const OpenInNewTab = ({
|
||||
href,
|
||||
children,
|
||||
}: {
|
||||
href: string;
|
||||
children: string;
|
||||
}) => (
|
||||
<a
|
||||
className="underline"
|
||||
href={href}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
{children}
|
||||
</a>
|
||||
);
|
||||
@@ -4,7 +4,6 @@ import { APIMessage, Message } from './types';
|
||||
|
||||
// ponyfill for missing ReadableStream asyncIterator on Safari
|
||||
import { asyncIterator } from '@sec-ant/readable-stream/ponyfill/asyncIterator';
|
||||
import { isDev } from '../Config';
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
export const isString = (x: any) => !!x.toLowerCase;
|
||||
@@ -23,7 +22,7 @@ export async function* getSSEStreamAsync(fetchResponse: Response) {
|
||||
.pipeThrough(new TextLineStream());
|
||||
// @ts-expect-error asyncIterator complains about type, but it should work
|
||||
for await (const line of asyncIterator(lines)) {
|
||||
if (isDev) console.log({ line });
|
||||
//if (isDev) console.log({ line });
|
||||
if (line.startsWith('data:') && !line.endsWith('[DONE]')) {
|
||||
const data = JSON.parse(line.slice(5));
|
||||
yield data;
|
||||
@@ -55,7 +54,7 @@ export const copyStr = (textToCopy: string) => {
|
||||
/**
|
||||
* filter out redundant fields upon sending to API
|
||||
*/
|
||||
export function normalizeMsgsForAPI(messages: Message[]) {
|
||||
export function normalizeMsgsForAPI(messages: Readonly<Message[]>) {
|
||||
return messages.map((msg) => {
|
||||
return {
|
||||
role: msg.role,
|
||||
@@ -85,3 +84,26 @@ export function classNames(classes: Record<string, boolean>): string {
|
||||
.map(([key, _]) => key)
|
||||
.join(' ');
|
||||
}
|
||||
|
||||
export const delay = (ms: number) =>
|
||||
new Promise((resolve) => setTimeout(resolve, ms));
|
||||
|
||||
export const throttle = <T extends unknown[]>(
|
||||
callback: (...args: T) => void,
|
||||
delay: number
|
||||
) => {
|
||||
let isWaiting = false;
|
||||
|
||||
return (...args: T) => {
|
||||
if (isWaiting) {
|
||||
return;
|
||||
}
|
||||
|
||||
callback(...args);
|
||||
isWaiting = true;
|
||||
|
||||
setTimeout(() => {
|
||||
isWaiting = false;
|
||||
}, delay);
|
||||
};
|
||||
};
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
// format: { [convId]: { id: string, lastModified: number, messages: [...] } }
|
||||
|
||||
import { CONFIG_DEFAULT } from '../Config';
|
||||
import { Conversation, Message } from './types';
|
||||
import { Conversation, Message, TimingReport } from './types';
|
||||
import Dexie, { Table } from 'dexie';
|
||||
|
||||
const event = new EventTarget();
|
||||
|
||||
@@ -17,85 +18,154 @@ const dispatchConversationChange = (convId: string) => {
|
||||
);
|
||||
};
|
||||
|
||||
const db = new Dexie('LlamacppWebui') as Dexie & {
|
||||
conversations: Table<Conversation>;
|
||||
messages: Table<Message>;
|
||||
};
|
||||
|
||||
// https://dexie.org/docs/Version/Version.stores()
|
||||
db.version(1).stores({
|
||||
// Unlike SQL, you don’t need to specify all properties but only the one you wish to index.
|
||||
conversations: '&id, lastModified',
|
||||
messages: '&id, convId, [convId+id], timestamp',
|
||||
});
|
||||
|
||||
// convId is a string prefixed with 'conv-'
|
||||
const StorageUtils = {
|
||||
/**
|
||||
* manage conversations
|
||||
*/
|
||||
getAllConversations(): Conversation[] {
|
||||
const res = [];
|
||||
for (const key in localStorage) {
|
||||
if (key.startsWith('conv-')) {
|
||||
res.push(JSON.parse(localStorage.getItem(key) ?? '{}'));
|
||||
}
|
||||
}
|
||||
res.sort((a, b) => b.lastModified - a.lastModified);
|
||||
return res;
|
||||
async getAllConversations(): Promise<Conversation[]> {
|
||||
await migrationLStoIDB().catch(console.error); // noop if already migrated
|
||||
return (await db.conversations.toArray()).sort(
|
||||
(a, b) => b.lastModified - a.lastModified
|
||||
);
|
||||
},
|
||||
/**
|
||||
* can return null if convId does not exist
|
||||
*/
|
||||
getOneConversation(convId: string): Conversation | null {
|
||||
return JSON.parse(localStorage.getItem(convId) || 'null');
|
||||
async getOneConversation(convId: string): Promise<Conversation | null> {
|
||||
return (await db.conversations.where('id').equals(convId).first()) ?? null;
|
||||
},
|
||||
/**
|
||||
* if convId does not exist, create one
|
||||
* get all message nodes in a conversation
|
||||
*/
|
||||
appendMsg(convId: string, msg: Message): void {
|
||||
if (msg.content === null) return;
|
||||
const conv = StorageUtils.getOneConversation(convId) || {
|
||||
id: convId,
|
||||
lastModified: Date.now(),
|
||||
messages: [],
|
||||
async getMessages(convId: string): Promise<Message[]> {
|
||||
return await db.messages.where({ convId }).toArray();
|
||||
},
|
||||
/**
|
||||
* use in conjunction with getMessages to filter messages by leafNodeId
|
||||
* includeRoot: whether to include the root node in the result
|
||||
* if node with leafNodeId does not exist, return the path with the latest timestamp
|
||||
*/
|
||||
filterByLeafNodeId(
|
||||
msgs: Readonly<Message[]>,
|
||||
leafNodeId: Message['id'],
|
||||
includeRoot: boolean
|
||||
): Readonly<Message[]> {
|
||||
const res: Message[] = [];
|
||||
const nodeMap = new Map<Message['id'], Message>();
|
||||
for (const msg of msgs) {
|
||||
nodeMap.set(msg.id, msg);
|
||||
}
|
||||
let startNode: Message | undefined = nodeMap.get(leafNodeId);
|
||||
if (!startNode) {
|
||||
// if not found, we return the path with the latest timestamp
|
||||
let latestTime = -1;
|
||||
for (const msg of msgs) {
|
||||
if (msg.timestamp > latestTime) {
|
||||
startNode = msg;
|
||||
latestTime = msg.timestamp;
|
||||
}
|
||||
}
|
||||
}
|
||||
// traverse the path from leafNodeId to root
|
||||
// startNode can never be undefined here
|
||||
let currNode: Message | undefined = startNode;
|
||||
while (currNode) {
|
||||
if (currNode.type !== 'root' || (currNode.type === 'root' && includeRoot))
|
||||
res.push(currNode);
|
||||
currNode = nodeMap.get(currNode.parent ?? -1);
|
||||
}
|
||||
res.sort((a, b) => a.timestamp - b.timestamp);
|
||||
return res;
|
||||
},
|
||||
/**
|
||||
* create a new conversation with a default root node
|
||||
*/
|
||||
async createConversation(name: string): Promise<Conversation> {
|
||||
const now = Date.now();
|
||||
const msgId = now;
|
||||
const conv: Conversation = {
|
||||
id: `conv-${now}`,
|
||||
lastModified: now,
|
||||
currNode: msgId,
|
||||
name,
|
||||
};
|
||||
conv.messages.push(msg);
|
||||
conv.lastModified = Date.now();
|
||||
localStorage.setItem(convId, JSON.stringify(conv));
|
||||
dispatchConversationChange(convId);
|
||||
await db.conversations.add(conv);
|
||||
// create a root node
|
||||
await db.messages.add({
|
||||
id: msgId,
|
||||
convId: conv.id,
|
||||
type: 'root',
|
||||
timestamp: now,
|
||||
role: 'system',
|
||||
content: '',
|
||||
parent: -1,
|
||||
children: [],
|
||||
});
|
||||
return conv;
|
||||
},
|
||||
/**
|
||||
* Get new conversation id
|
||||
* if convId does not exist, throw an error
|
||||
*/
|
||||
getNewConvId(): string {
|
||||
return `conv-${Date.now()}`;
|
||||
async appendMsg(
|
||||
msg: Exclude<Message, 'parent' | 'children'>,
|
||||
parentNodeId: Message['id']
|
||||
): Promise<void> {
|
||||
if (msg.content === null) return;
|
||||
const { convId } = msg;
|
||||
await db.transaction('rw', db.conversations, db.messages, async () => {
|
||||
const conv = await StorageUtils.getOneConversation(convId);
|
||||
const parentMsg = await db.messages
|
||||
.where({ convId, id: parentNodeId })
|
||||
.first();
|
||||
// update the currNode of conversation
|
||||
if (!conv) {
|
||||
throw new Error(`Conversation ${convId} does not exist`);
|
||||
}
|
||||
if (!parentMsg) {
|
||||
throw new Error(
|
||||
`Parent message ID ${parentNodeId} does not exist in conversation ${convId}`
|
||||
);
|
||||
}
|
||||
await db.conversations.update(convId, {
|
||||
lastModified: Date.now(),
|
||||
currNode: msg.id,
|
||||
});
|
||||
// update parent
|
||||
await db.messages.update(parentNodeId, {
|
||||
children: [...parentMsg.children, msg.id],
|
||||
});
|
||||
// create message
|
||||
await db.messages.add({
|
||||
...msg,
|
||||
parent: parentNodeId,
|
||||
children: [],
|
||||
});
|
||||
});
|
||||
dispatchConversationChange(convId);
|
||||
},
|
||||
/**
|
||||
* remove conversation by id
|
||||
*/
|
||||
remove(convId: string): void {
|
||||
localStorage.removeItem(convId);
|
||||
async remove(convId: string): Promise<void> {
|
||||
await db.transaction('rw', db.conversations, db.messages, async () => {
|
||||
await db.conversations.delete(convId);
|
||||
await db.messages.where({ convId }).delete();
|
||||
});
|
||||
dispatchConversationChange(convId);
|
||||
},
|
||||
/**
|
||||
* remove all conversations
|
||||
*/
|
||||
filterAndKeepMsgs(
|
||||
convId: string,
|
||||
predicate: (msg: Message) => boolean
|
||||
): void {
|
||||
const conv = StorageUtils.getOneConversation(convId);
|
||||
if (!conv) return;
|
||||
conv.messages = conv.messages.filter(predicate);
|
||||
conv.lastModified = Date.now();
|
||||
localStorage.setItem(convId, JSON.stringify(conv));
|
||||
dispatchConversationChange(convId);
|
||||
},
|
||||
/**
|
||||
* remove last message from conversation
|
||||
*/
|
||||
popMsg(convId: string): Message | undefined {
|
||||
const conv = StorageUtils.getOneConversation(convId);
|
||||
if (!conv) return;
|
||||
const msg = conv.messages.pop();
|
||||
conv.lastModified = Date.now();
|
||||
if (conv.messages.length === 0) {
|
||||
StorageUtils.remove(convId);
|
||||
} else {
|
||||
localStorage.setItem(convId, JSON.stringify(conv));
|
||||
}
|
||||
dispatchConversationChange(convId);
|
||||
return msg;
|
||||
},
|
||||
|
||||
// event listeners
|
||||
onConversationChanged(callback: CallbackConversationChanged) {
|
||||
@@ -136,3 +206,79 @@ const StorageUtils = {
|
||||
};
|
||||
|
||||
export default StorageUtils;
|
||||
|
||||
// Migration from localStorage to IndexedDB
|
||||
|
||||
// these are old types, LS prefix stands for LocalStorage
|
||||
interface LSConversation {
|
||||
id: string; // format: `conv-{timestamp}`
|
||||
lastModified: number; // timestamp from Date.now()
|
||||
messages: LSMessage[];
|
||||
}
|
||||
interface LSMessage {
|
||||
id: number;
|
||||
role: 'user' | 'assistant' | 'system';
|
||||
content: string;
|
||||
timings?: TimingReport;
|
||||
}
|
||||
async function migrationLStoIDB() {
|
||||
if (localStorage.getItem('migratedToIDB')) return;
|
||||
const res: LSConversation[] = [];
|
||||
for (const key in localStorage) {
|
||||
if (key.startsWith('conv-')) {
|
||||
res.push(JSON.parse(localStorage.getItem(key) ?? '{}'));
|
||||
}
|
||||
}
|
||||
if (res.length === 0) return;
|
||||
await db.transaction('rw', db.conversations, db.messages, async () => {
|
||||
let migratedCount = 0;
|
||||
for (const conv of res) {
|
||||
const { id: convId, lastModified, messages } = conv;
|
||||
const firstMsg = messages[0];
|
||||
const lastMsg = messages.at(-1);
|
||||
if (messages.length < 2 || !firstMsg || !lastMsg) {
|
||||
console.log(
|
||||
`Skipping conversation ${convId} with ${messages.length} messages`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
const name = firstMsg.content ?? '(no messages)';
|
||||
await db.conversations.add({
|
||||
id: convId,
|
||||
lastModified,
|
||||
currNode: lastMsg.id,
|
||||
name,
|
||||
});
|
||||
const rootId = messages[0].id - 2;
|
||||
await db.messages.add({
|
||||
id: rootId,
|
||||
convId: convId,
|
||||
type: 'root',
|
||||
timestamp: rootId,
|
||||
role: 'system',
|
||||
content: '',
|
||||
parent: -1,
|
||||
children: [firstMsg.id],
|
||||
});
|
||||
for (let i = 0; i < messages.length; i++) {
|
||||
const msg = messages[i];
|
||||
await db.messages.add({
|
||||
...msg,
|
||||
type: 'text',
|
||||
convId: convId,
|
||||
timestamp: msg.id,
|
||||
parent: i === 0 ? rootId : messages[i - 1].id,
|
||||
children: i === messages.length - 1 ? [] : [messages[i + 1].id],
|
||||
});
|
||||
}
|
||||
migratedCount++;
|
||||
console.log(
|
||||
`Migrated conversation ${convId} with ${messages.length} messages`
|
||||
);
|
||||
}
|
||||
console.log(
|
||||
`Migrated ${migratedCount} conversations from localStorage to IndexedDB`
|
||||
);
|
||||
localStorage.setItem('migratedToIDB', '1');
|
||||
});
|
||||
}
|
||||
|
||||
@@ -5,11 +5,46 @@ export interface TimingReport {
|
||||
predicted_ms: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* What is conversation "branching"? It is a feature that allows the user to edit an old message in the history, while still keeping the conversation flow.
|
||||
* Inspired by ChatGPT / Claude / Hugging Chat where you edit a message, a new branch of the conversation is created, and the old message is still visible.
|
||||
*
|
||||
* We use the same node-based structure like other chat UIs, where each message has a parent and children. A "root" message is the first message in a conversation, which will not be displayed in the UI.
|
||||
*
|
||||
* root
|
||||
* ├── message 1
|
||||
* │ └── message 2
|
||||
* │ └── message 3
|
||||
* └── message 4
|
||||
* └── message 5
|
||||
*
|
||||
* In the above example, assuming that user wants to edit message 2, a new branch will be created:
|
||||
*
|
||||
* ├── message 2
|
||||
* │ └── message 3
|
||||
* └── message 6
|
||||
*
|
||||
* Message 2 and 6 are siblings, and message 6 is the new branch.
|
||||
*
|
||||
* We only need to know the last node (aka leaf) to get the current branch. In the above example, message 5 is the leaf of branch containing message 4 and 5.
|
||||
*
|
||||
* For the implementation:
|
||||
* - StorageUtils.getMessages() returns list of all nodes
|
||||
* - StorageUtils.filterByLeafNodeId() filters the list of nodes from a given leaf node
|
||||
*/
|
||||
|
||||
// Note: the term "message" and "node" are used interchangeably in this context
|
||||
export interface Message {
|
||||
id: number;
|
||||
convId: string;
|
||||
type: 'text' | 'root';
|
||||
timestamp: number; // timestamp from Date.now()
|
||||
role: 'user' | 'assistant' | 'system';
|
||||
content: string;
|
||||
timings?: TimingReport;
|
||||
// node based system for branching
|
||||
parent: Message['id'];
|
||||
children: Message['id'][];
|
||||
}
|
||||
|
||||
export type APIMessage = Pick<Message, 'role' | 'content'>;
|
||||
@@ -17,9 +52,26 @@ export type APIMessage = Pick<Message, 'role' | 'content'>;
|
||||
export interface Conversation {
|
||||
id: string; // format: `conv-{timestamp}`
|
||||
lastModified: number; // timestamp from Date.now()
|
||||
messages: Message[];
|
||||
currNode: Message['id']; // the current message node being viewed
|
||||
name: string;
|
||||
}
|
||||
|
||||
export interface ViewingChat {
|
||||
conv: Readonly<Conversation>;
|
||||
messages: Readonly<Message[]>;
|
||||
}
|
||||
|
||||
export type PendingMessage = Omit<Message, 'content'> & {
|
||||
content: string | null;
|
||||
};
|
||||
|
||||
export enum CanvasType {
|
||||
PY_INTERPRETER,
|
||||
}
|
||||
|
||||
export interface CanvasPyInterpreter {
|
||||
type: CanvasType.PY_INTERPRETER;
|
||||
content: string;
|
||||
}
|
||||
|
||||
export type CanvasData = CanvasPyInterpreter;
|
||||
|
||||
@@ -72,5 +72,9 @@ export default defineConfig({
|
||||
proxy: {
|
||||
'/v1': 'http://localhost:8080',
|
||||
},
|
||||
headers: {
|
||||
'Cross-Origin-Embedder-Policy': 'require-corp',
|
||||
'Cross-Origin-Opener-Policy': 'same-origin',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@@ -10,8 +10,6 @@ extern "C" {
|
||||
#define GGML_VK_NAME "Vulkan"
|
||||
#define GGML_VK_MAX_DEVICES 16
|
||||
|
||||
GGML_BACKEND_API void ggml_vk_instance_init(void);
|
||||
|
||||
// backend API
|
||||
GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);
|
||||
|
||||
|
||||
@@ -198,7 +198,7 @@
|
||||
|
||||
#ifndef __GNUC__
|
||||
# define GGML_ATTRIBUTE_FORMAT(...)
|
||||
#elif defined(__MINGW32__)
|
||||
#elif defined(__MINGW32__) && !defined(__clang__)
|
||||
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||
#else
|
||||
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||
|
||||
@@ -473,7 +473,6 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
|
||||
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
|
||||
GGML_TABLE_END()
|
||||
|
||||
//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
|
||||
GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
|
||||
0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
|
||||
0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
|
||||
@@ -508,7 +507,6 @@ GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
|
||||
0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
|
||||
0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
|
||||
GGML_TABLE_END()
|
||||
//#endif
|
||||
|
||||
|
||||
GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,10 +7,8 @@
|
||||
#include "ggml-cpu-impl.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-quants.h"
|
||||
#include "ggml-cpu-quants.h"
|
||||
#include "ggml-threading.h"
|
||||
#include "amx/amx.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
@@ -1291,7 +1289,7 @@ struct ggml_threadpool {
|
||||
atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
|
||||
atomic_int GGML_CACHE_ALIGN n_barrier;
|
||||
atomic_int GGML_CACHE_ALIGN n_barrier_passed;
|
||||
atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
|
||||
atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
|
||||
|
||||
// these are atomic as an annotation for thread-sanitizer
|
||||
atomic_bool stop; // Used for stopping the threadpool altogether
|
||||
@@ -7490,6 +7488,7 @@ UseGgmlGemm1:;
|
||||
if (src1->type != vec_dot_type) {
|
||||
char * wdata = params->wdata;
|
||||
|
||||
const size_t nbw0 = ggml_type_size(vec_dot_type);
|
||||
const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
|
||||
const size_t nbw2 = nbw1*ne11;
|
||||
const size_t nbw3 = nbw2*ne12;
|
||||
@@ -7497,6 +7496,7 @@ UseGgmlGemm1:;
|
||||
assert(params->wsize >= ne13*nbw3);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
#if 0
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
|
||||
@@ -7506,6 +7506,20 @@ UseGgmlGemm1:;
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||
size_t bs = ggml_blck_size(vec_dot_type);
|
||||
int64_t ne10_block_start = (ith * ne10/bs) / nth;
|
||||
int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
|
||||
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
|
||||
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
|
||||
(ne10_block_end - ne10_block_start) * bs);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
if (ith == 0) {
|
||||
@@ -7593,7 +7607,6 @@ UseGgmlGemm2:;
|
||||
if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
|
||||
num_rows_per_vec_dot = 1;
|
||||
}
|
||||
|
||||
ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
|
||||
|
||||
if (nth >= nchunk0 * nchunk1) {
|
||||
@@ -7606,6 +7619,84 @@ UseGgmlGemm2:;
|
||||
|
||||
// ggml_compute_forward_mul_mat_id
|
||||
|
||||
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]
|
||||
|
||||
struct mmid_row_mapping {
|
||||
int32_t i1;
|
||||
int32_t i2;
|
||||
};
|
||||
|
||||
static void ggml_compute_forward_mul_mat_id_one_chunk(
|
||||
struct ggml_tensor * dst,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
const struct ggml_tensor * ids,
|
||||
const int64_t cur_a,
|
||||
const int64_t ir0_start,
|
||||
const int64_t ir0_end,
|
||||
const int64_t ir1_start,
|
||||
const int64_t ir1_end,
|
||||
const char * src0_cur,
|
||||
const struct mmid_row_mapping * matrix_rows,
|
||||
const size_t row_size,
|
||||
const bool src1_cont,
|
||||
const void * wdata) {
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const enum ggml_type type = src0->type;
|
||||
|
||||
ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
|
||||
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
|
||||
|
||||
const int64_t blck_0 = 16;
|
||||
const int64_t blck_1 = 16;
|
||||
|
||||
float tmp[16];
|
||||
|
||||
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
||||
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {
|
||||
const int64_t _i12 = ir1; // logical row index for this expert
|
||||
|
||||
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
|
||||
const int id = row_mapping.i1; // selected expert index
|
||||
|
||||
const int64_t i11 = id % ne11;
|
||||
const int64_t i12 = row_mapping.i2; // row index in src1
|
||||
|
||||
const int64_t i1 = id; // selected expert index
|
||||
const int64_t i2 = i12; // row
|
||||
|
||||
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
|
||||
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
|
||||
// the original src1 data pointer, so we should index using the indices directly
|
||||
// TODO: this is a bit of a hack, we should probably have a better way to handle this
|
||||
const char * src1_col = (const char *) wdata +
|
||||
(src1_cont || src1->type != vec_dot_type
|
||||
? (i11 + i12*ne11)*row_size
|
||||
: (i11*nb11 + i12*nb12));
|
||||
|
||||
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
|
||||
|
||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
|
||||
vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
|
||||
}
|
||||
|
||||
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
|
||||
|
||||
void * ptr = *p;
|
||||
ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
|
||||
*p = (void *) ((char *) ptr + size);
|
||||
return ptr;
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_mul_mat_id(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
@@ -7623,7 +7714,6 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
|
||||
const bool src1_cont = ggml_is_contiguous(src1);
|
||||
|
||||
ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
|
||||
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
|
||||
ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
|
||||
|
||||
@@ -7641,21 +7731,27 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
const int n_ids = ids->ne[0]; // n_expert_used
|
||||
const int n_as = ne02; // n_expert
|
||||
|
||||
char * wdata_src1_end = (src1->type == vec_dot_type) ?
|
||||
(char *) params->wdata :
|
||||
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
|
||||
void * wdata_cur = params->wdata;
|
||||
|
||||
struct mmid_row_mapping {
|
||||
int32_t i1;
|
||||
int32_t i2;
|
||||
};
|
||||
if (src1->type != vec_dot_type) {
|
||||
incr_ptr_aligned(&wdata_cur, ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
|
||||
}
|
||||
|
||||
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
||||
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
|
||||
int64_t * matrix_row_counts = // [n_as]
|
||||
incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t));
|
||||
|
||||
struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]]
|
||||
incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t));
|
||||
|
||||
char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as]
|
||||
incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE);
|
||||
|
||||
GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata));
|
||||
|
||||
if (src1->type != vec_dot_type) {
|
||||
char * wdata = params->wdata;
|
||||
|
||||
const size_t nbw0 = ggml_type_size(vec_dot_type);
|
||||
const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
|
||||
const size_t nbw2 = nbw1*ne11;
|
||||
const size_t nbw3 = nbw2*ne12;
|
||||
@@ -7663,19 +7759,32 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
assert(params->wsize >= ne13*nbw3);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
#if 0
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
|
||||
for (int64_t i12 = ith; i12 < ne12; i12 += nth) {
|
||||
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
|
||||
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
|
||||
ne10);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||
size_t bs = ggml_blck_size(vec_dot_type);
|
||||
int64_t ne10_block_start = (ith * ne10/bs) / nth;
|
||||
int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
|
||||
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
|
||||
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
|
||||
(ne10_block_end - ne10_block_start) * bs);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
|
||||
|
||||
if (ith == 0) {
|
||||
// initialize matrix_row_counts
|
||||
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
||||
@@ -7693,9 +7802,14 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
}
|
||||
}
|
||||
|
||||
// reset current_chunk
|
||||
for (int cur_a = ith; cur_a < n_as; cur_a += nth) {
|
||||
atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
|
||||
*current_chunk_ctr = nth;
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
// compute each matrix multiplication in sequence
|
||||
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
|
||||
const int64_t cne1 = matrix_row_counts[cur_a];
|
||||
|
||||
@@ -7703,84 +7817,64 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
continue;
|
||||
}
|
||||
|
||||
const char * src0_cur = (const char *) src0->data + cur_a*nb02;
|
||||
|
||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const char * src0_cur = (const char *) src0->data + cur_a * nb02;
|
||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
|
||||
const int64_t nr0 = ne01; // src0 rows
|
||||
const int64_t nr1 = cne1; // src1 rows
|
||||
const int64_t nr0 = ne01;
|
||||
const int64_t nr1 = cne1;
|
||||
|
||||
// distribute the thread work across the inner or outer loop based on which one is larger
|
||||
int chunk_size = 16;
|
||||
if (nr0 == 1 || nr1 == 1) {
|
||||
chunk_size = 64;
|
||||
}
|
||||
|
||||
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
|
||||
const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
|
||||
#if defined(__aarch64__)
|
||||
// disable for ARM
|
||||
const bool disable_chunking = true;
|
||||
#else
|
||||
// disable for NUMA
|
||||
const bool disable_chunking = ggml_is_numa();
|
||||
#endif // defined(__aarch64__)
|
||||
|
||||
const int64_t ith0 = ith % nth0;
|
||||
const int64_t ith1 = ith / nth0;
|
||||
int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
|
||||
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
|
||||
|
||||
const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
|
||||
const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
|
||||
if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
|
||||
nchunk0 = nr0 > nr1 ? nth : 1;
|
||||
nchunk1 = nr0 > nr1 ? 1 : nth;
|
||||
}
|
||||
|
||||
const int64_t ir010 = dr0*ith0;
|
||||
const int64_t ir011 = MIN(ir010 + dr0, nr0);
|
||||
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
|
||||
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
|
||||
|
||||
const int64_t ir110 = dr1*ith1;
|
||||
const int64_t ir111 = MIN(ir110 + dr1, nr1);
|
||||
int current_chunk = ith;
|
||||
|
||||
// threads with no work simply yield (not sure if it helps)
|
||||
//if (ir010 >= ir011 || ir110 >= ir111) {
|
||||
// sched_yield();
|
||||
// continue;
|
||||
//}
|
||||
atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
|
||||
|
||||
// block-tiling attempt
|
||||
const int64_t blck_0 = 16;
|
||||
const int64_t blck_1 = 16;
|
||||
while (current_chunk < nchunk0 * nchunk1) {
|
||||
const int64_t ith0 = current_chunk % nchunk0;
|
||||
const int64_t ith1 = current_chunk / nchunk0;
|
||||
|
||||
// attempt to reduce false-sharing (does not seem to make a difference)
|
||||
float tmp[16];
|
||||
const int64_t ir0_start = dr0 * ith0;
|
||||
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
|
||||
|
||||
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
|
||||
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
|
||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
|
||||
const int64_t _i12 = ir1; // logical row index for this expert
|
||||
const int64_t ir1_start = dr1 * ith1;
|
||||
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
||||
|
||||
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
|
||||
const int id = row_mapping.i1; // selected expert index
|
||||
ggml_compute_forward_mul_mat_id_one_chunk(
|
||||
dst, src0, src1, ids, cur_a,
|
||||
ir0_start, ir0_end, ir1_start, ir1_end,
|
||||
src0_cur, matrix_rows, row_size, src1_cont, wdata
|
||||
);
|
||||
|
||||
const int64_t i11 = id % ne11;
|
||||
const int64_t i12 = row_mapping.i2; // row index in src1
|
||||
|
||||
const int64_t i1 = id; // selected expert index
|
||||
const int64_t i2 = i12; // row
|
||||
|
||||
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
|
||||
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
|
||||
// the original src1 data pointer, so we should index using the indices directly
|
||||
// TODO: this is a bit of a hack, we should probably have a better way to handle this
|
||||
const char * src1_col = (const char *) wdata +
|
||||
(src1_cont || src1->type != vec_dot_type
|
||||
? (i11 + i12*ne11)*row_size
|
||||
: (i11*nb11 + i12*nb12));
|
||||
|
||||
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
|
||||
|
||||
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
|
||||
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
|
||||
//}
|
||||
|
||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
|
||||
vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
|
||||
}
|
||||
|
||||
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
|
||||
}
|
||||
if (nth >= nchunk0 * nchunk1) {
|
||||
break;
|
||||
}
|
||||
|
||||
current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
#undef MMID_MATRIX_ROW
|
||||
}
|
||||
|
||||
// ggml_compute_forward_out_prod
|
||||
@@ -9074,10 +9168,6 @@ static void ggml_compute_forward_clamp_f32(
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
float min;
|
||||
float max;
|
||||
memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
|
||||
@@ -13717,14 +13807,19 @@ struct ggml_cplan ggml_graph_plan(
|
||||
cur = 0;
|
||||
const struct ggml_tensor * src0 = node->src[0];
|
||||
const struct ggml_tensor * src1 = node->src[1];
|
||||
const struct ggml_tensor * ids = node->src[2];
|
||||
const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
|
||||
if (src1->type != vec_dot_type) {
|
||||
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
|
||||
}
|
||||
const int n_as = src0->ne[2];
|
||||
cur += GGML_PAD(cur, sizeof(int64_t)); // align
|
||||
cur += n_as * sizeof(int64_t); // matrix_row_counts
|
||||
cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
|
||||
// src1
|
||||
if (src1->type != vec_dot_type) {
|
||||
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)) + sizeof(int64_t);
|
||||
}
|
||||
// matrix_row_counts
|
||||
cur += n_as * sizeof(int64_t) + sizeof(int64_t);
|
||||
// matrix_rows
|
||||
cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t);
|
||||
// atomic_current_chunk
|
||||
cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE;
|
||||
} break;
|
||||
case GGML_OP_OUT_PROD:
|
||||
{
|
||||
@@ -13856,9 +13951,13 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
tp->ec = GGML_STATUS_ABORTED;
|
||||
}
|
||||
|
||||
ggml_barrier(state->threadpool);
|
||||
if (node_n + 1 < cgraph->n_nodes) {
|
||||
ggml_barrier(state->threadpool);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_barrier(state->threadpool);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -284,14 +284,14 @@ struct ggml_backend_cpu_device_context {
|
||||
&hKey) == ERROR_SUCCESS) {
|
||||
DWORD cpu_brand_size = 0;
|
||||
if (RegQueryValueExA(hKey,
|
||||
TEXT("ProcessorNameString"),
|
||||
"ProcessorNameString",
|
||||
NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
&cpu_brand_size) == ERROR_SUCCESS) {
|
||||
description.resize(cpu_brand_size);
|
||||
if (RegQueryValueExA(hKey,
|
||||
TEXT("ProcessorNameString"),
|
||||
"ProcessorNameString",
|
||||
NULL,
|
||||
NULL,
|
||||
(LPBYTE)&description[0], // NOLINT
|
||||
@@ -534,9 +534,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
|
||||
if (ggml_cpu_has_dotprod()) {
|
||||
features.push_back({ "DOTPROD", "1" });
|
||||
}
|
||||
if (ggml_cpu_has_matmul_int8()) {
|
||||
features.push_back({ "MATMUL_INT8", "1" });
|
||||
}
|
||||
if (ggml_cpu_get_sve_cnt() > 0) {
|
||||
static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
|
||||
features.push_back({ "SVE_CNT", sve_cnt.c_str() });
|
||||
|
||||
@@ -280,14 +280,6 @@ template <> inline __m256bh load(const float *p) {
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
||||
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
||||
static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// FLOATING POINT MATRIX MULTIPLICATION
|
||||
|
||||
@@ -614,6 +606,14 @@ class tinyBLAS_Q0_AVX {
|
||||
TC *C, int64_t ldc,
|
||||
int ith, int nth)
|
||||
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||
const int8_t kvalues_iq4nl[16] = {
|
||||
-127, -104, -83, -65,
|
||||
-49, -35, -22, -10,
|
||||
1, 13, 25, 38,
|
||||
53, 69, 89, 113
|
||||
};
|
||||
|
||||
iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
|
||||
}
|
||||
|
||||
void matmul(int64_t m, int64_t n) {
|
||||
@@ -1038,6 +1038,7 @@ class tinyBLAS_Q0_AVX {
|
||||
const int64_t ldc;
|
||||
const int ith;
|
||||
const int nth;
|
||||
__m128i iq4nlt;
|
||||
};
|
||||
#endif // __AVX__
|
||||
|
||||
|
||||
@@ -15,9 +15,9 @@ if (CUDAToolkit_FOUND)
|
||||
if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
|
||||
set(CMAKE_CUDA_ARCHITECTURES "native")
|
||||
elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
|
||||
set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
|
||||
set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80")
|
||||
else()
|
||||
set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75")
|
||||
set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75;80")
|
||||
endif()
|
||||
endif()
|
||||
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
||||
|
||||
@@ -71,6 +71,47 @@
|
||||
#define GGML_CUDA_CC_QY1 210
|
||||
#define GGML_CUDA_CC_QY2 220
|
||||
|
||||
#ifdef __CUDA_ARCH_LIST__
|
||||
constexpr bool ggml_cuda_has_arch_impl(int) {
|
||||
return false;
|
||||
}
|
||||
|
||||
template<class ... Archs>
|
||||
constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
|
||||
return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
|
||||
}
|
||||
|
||||
constexpr bool ggml_cuda_has_arch(const int arch) {
|
||||
return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
|
||||
}
|
||||
|
||||
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
|
||||
if (cur == 0) {
|
||||
GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
|
||||
template<class ... Archs>
|
||||
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
|
||||
if (first <= arch && first > cur) {
|
||||
return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);
|
||||
} else {
|
||||
return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
|
||||
return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
|
||||
}
|
||||
#else
|
||||
static int ggml_cuda_highest_compiled_arch(const int arch) {
|
||||
return arch;
|
||||
}
|
||||
#endif // __CUDA_ARCH_LIST__
|
||||
|
||||
// ---------------------------------------------------------------------------------------------------------
|
||||
|
||||
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
@@ -124,11 +165,11 @@ static const char * cu_get_error_str(CUresult err) {
|
||||
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
||||
#endif
|
||||
|
||||
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
||||
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
|
||||
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
||||
#else
|
||||
#define GGML_CUDA_ASSUME(x)
|
||||
#endif // CUDART_VERSION >= 11100
|
||||
#endif // CUDART_VERSION >= 11010
|
||||
|
||||
#ifdef GGML_CUDA_F16
|
||||
typedef half dfloat; // dequantize float
|
||||
@@ -162,18 +203,32 @@ typedef float2 dfloat2;
|
||||
#define FLASH_ATTN_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
|
||||
static constexpr bool fast_fp16_available(const int cc) {
|
||||
static bool fp16_available(const int cc) {
|
||||
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
|
||||
}
|
||||
|
||||
static bool fast_fp16_available(const int cc) {
|
||||
return fp16_available(cc) && cc != 610;
|
||||
}
|
||||
|
||||
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||
static bool fast_fp16_hardware_available(const int cc) {
|
||||
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
|
||||
}
|
||||
|
||||
// Any FP16 tensor cores are available.
|
||||
static constexpr bool fp16_mma_available(const int cc) {
|
||||
// Any FP16 tensor core instructions are available for ggml code.
|
||||
static bool fp16_mma_available(const int cc) {
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
|
||||
}
|
||||
|
||||
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||
static bool fp16_mma_hardware_available(const int cc) {
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
|
||||
}
|
||||
|
||||
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
|
||||
static constexpr bool new_mma_available(const int cc) {
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
|
||||
static bool new_mma_available(const int cc) {
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
|
||||
}
|
||||
|
||||
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
|
||||
|
||||
@@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||||
case GGML_TYPE_Q5_1:
|
||||
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
||||
case GGML_TYPE_Q8_0:
|
||||
if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) {
|
||||
if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
|
||||
return dequantize_block_q8_0_f16_cuda;
|
||||
}
|
||||
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
||||
|
||||
@@ -178,11 +178,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
int major_version = 0;
|
||||
size_t version_length = 0;
|
||||
if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
|
||||
std::string version(version_length, '\0');
|
||||
std::vector<char> version(version_length+1, '\0');
|
||||
if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
|
||||
version.resize(::strlen(version.c_str()));
|
||||
version.resize(::strlen(version.data()));
|
||||
int parsed_value = 0;
|
||||
if (std::from_chars(version.c_str(), version.c_str() + version.length(), parsed_value).ec == std::errc()) {
|
||||
if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) {
|
||||
major_version = parsed_value;
|
||||
}
|
||||
}
|
||||
@@ -1480,12 +1480,7 @@ static void ggml_cuda_op_mul_mat(
|
||||
const size_t nbytes_data = ggml_nbytes(src0);
|
||||
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
|
||||
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
|
||||
// TODO: remove this for MUSA once the Guilty Lockup issue is resolved
|
||||
#ifndef GGML_USE_MUSA
|
||||
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
|
||||
#else // GGML_USE_MUSA
|
||||
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
|
||||
#endif // !GGML_USE_MUSA
|
||||
}
|
||||
|
||||
// If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
|
||||
@@ -1867,14 +1862,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
|
||||
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
|
||||
}
|
||||
} else {
|
||||
const int cc = ggml_cuda_info().devices[ctx.device].cc;
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
|
||||
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
|
||||
}
|
||||
|
||||
// debug helpers
|
||||
@@ -2840,7 +2835,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
|
||||
return false;
|
||||
}
|
||||
|
||||
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
||||
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
|
||||
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
|
||||
if (err != cudaSuccess) {
|
||||
// clear the error
|
||||
@@ -2852,8 +2847,10 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
|
||||
}
|
||||
return true;
|
||||
#else
|
||||
GGML_UNUSED(buffer);
|
||||
GGML_UNUSED(size);
|
||||
return false;
|
||||
#endif
|
||||
#endif // CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
|
||||
@@ -3205,8 +3202,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
|
||||
return true;
|
||||
}
|
||||
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
||||
return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
|
||||
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||
}
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
|
||||
@@ -18,7 +18,7 @@ void ggml_cuda_op_mul_mat_q(
|
||||
const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
|
||||
|
||||
int id = ggml_cuda_get_device();
|
||||
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
|
||||
// the main device has a larger memory buffer to hold the results from all GPUs
|
||||
// nrows_dst == nrows of the matrix that the kernel writes into
|
||||
@@ -27,7 +27,8 @@ void ggml_cuda_op_mul_mat_q(
|
||||
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
|
||||
// Also its fixup needs to allocate a temporary buffer in the memory pool.
|
||||
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
|
||||
const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
|
||||
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
|
||||
cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
|
||||
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
|
||||
|
||||
switch (src0->type) {
|
||||
@@ -136,7 +137,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (cc < GGML_CUDA_CC_DP4A) {
|
||||
if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -145,8 +146,8 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
#endif //GGML_CUDA_FORCE_MMQ
|
||||
|
||||
if (cc < GGML_CUDA_CC_OFFSET_AMD) {
|
||||
return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
@@ -86,12 +86,13 @@ struct tile_x_sizes {
|
||||
int sc;
|
||||
};
|
||||
|
||||
static constexpr int get_mmq_x_max_host(const int cc) {
|
||||
static int get_mmq_x_max_host(const int cc) {
|
||||
return new_mma_available(cc) ? 128 :
|
||||
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
|
||||
#ifdef GGML_CUDA_FORCE_MMQ
|
||||
cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64;
|
||||
128 : 64;
|
||||
#else
|
||||
cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
|
||||
MMQ_DP4A_MAX_BATCH_SIZE : 64;
|
||||
#endif // GGML_CUDA_FORCE_MMQ
|
||||
}
|
||||
|
||||
@@ -119,8 +120,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static constexpr int get_mmq_y_host(const int cc) {
|
||||
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64);
|
||||
static int get_mmq_y_host(const int cc) {
|
||||
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
|
||||
(ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_mmq_y_device() {
|
||||
@@ -2828,7 +2830,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
||||
const int mmq_x_max = get_mmq_x_max_host(cc);
|
||||
const int mmq_y = get_mmq_y_host(cc);
|
||||
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
|
||||
const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
|
||||
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
|
||||
|
||||
int mmq_x_best = 0;
|
||||
int nparts_best = INT_MAX;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
|
||||
#define USE_CUB
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
|
||||
|
||||
#ifdef USE_CUB
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
@@ -143,6 +143,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_rms_norm;
|
||||
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
|
||||
cl_kernel kernel_soft_max, kernel_soft_max_4;
|
||||
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
|
||||
cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
|
||||
cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
|
||||
cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
|
||||
@@ -614,6 +615,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf_8", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program, "kernel_soft_max", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f32", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err));
|
||||
@@ -1044,8 +1047,16 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
return true;
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
return op->ne[3] == 1;
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE: {
|
||||
const int mode = ((const int32_t *) op->op_params)[2];
|
||||
if (mode & GGML_ROPE_TYPE_MROPE) {
|
||||
return false;
|
||||
}
|
||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -3666,6 +3677,8 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||
|
||||
// Local size must be wave size. Each workgroup is a wave, working on a row,
|
||||
// where a row corresponds to leading dimension.
|
||||
int nth = MIN(32, ne00);
|
||||
@@ -3683,9 +3696,17 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
|
||||
cl_kernel kernel;
|
||||
|
||||
if (ne00%4 == 0) {
|
||||
kernel = backend_ctx->kernel_soft_max_4;
|
||||
if (use_f16) {
|
||||
kernel = backend_ctx->kernel_soft_max_4_f16;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_soft_max_4;
|
||||
}
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_soft_max;
|
||||
if (use_f16) {
|
||||
kernel = backend_ctx->kernel_soft_max_f16;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_soft_max;
|
||||
}
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
@@ -3766,7 +3787,8 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
const int nb2 = dst ? dst->nb[2] : 0;
|
||||
const int nb3 = dst ? dst->nb[3] : 0;
|
||||
|
||||
GGML_ASSERT(ne10 == ne02);
|
||||
GGML_ASSERT(ne10 % ne02 == 0);
|
||||
GGML_ASSERT(ne10 >= ne02);
|
||||
|
||||
int nth = MIN(64, ne00);
|
||||
|
||||
|
||||
@@ -679,6 +679,9 @@ kernel void kernel_diag_mask_inf_8(
|
||||
//------------------------------------------------------------------------------
|
||||
// softmax
|
||||
//------------------------------------------------------------------------------
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_soft_max(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
@@ -811,6 +814,141 @@ kernel void kernel_soft_max_4(
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_soft_max_f16(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
global half * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
float scale,
|
||||
float max_bias,
|
||||
float m0,
|
||||
float m1,
|
||||
int n_head_log2
|
||||
) {
|
||||
src0 = (global float *)((global char *)src0 + offset0);
|
||||
src1 = (global half *)((global char *)src1 + offset1);
|
||||
dst = (global float *)((global char *)dst + offsetd);
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0;
|
||||
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
int h = i02;
|
||||
|
||||
float base = h < n_head_log2 ? m0 : m1;
|
||||
int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float lmax = -INFINITY;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
||||
}
|
||||
float max = sub_group_reduce_max(lmax);
|
||||
|
||||
// parallel sum
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);
|
||||
lsum += exp_psrc0;
|
||||
// Remember the result of exp here. exp is expensive, so we really do not
|
||||
// wish to compute it twice.
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
|
||||
const float sum = sub_group_reduce_add(lsum);
|
||||
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
pdst[i00] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_soft_max_4_f16(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
global half * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
float scale,
|
||||
float max_bias,
|
||||
float m0,
|
||||
float m1,
|
||||
int n_head_log2
|
||||
) {
|
||||
src0 = (global float *)((global char *)src0 + offset0);
|
||||
src1 = (global half *)((global char *)src1 + offset1);
|
||||
dst = (global float *)((global char *)dst + offsetd);
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0;
|
||||
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
int h = i02;
|
||||
|
||||
float base = h < n_head_log2 ? m0 : m1;
|
||||
int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float4 lmax4 = -INFINITY;
|
||||
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
|
||||
}
|
||||
float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3));
|
||||
|
||||
const float max = sub_group_reduce_max(lmax);
|
||||
|
||||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
|
||||
|
||||
const float sum = sub_group_reduce_add(lsum);
|
||||
|
||||
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
||||
pdst4[i00] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_rope
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,7 +12,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
if (gl_LocalInvocationIndex.x != 0) {
|
||||
return;
|
||||
|
||||
@@ -217,7 +217,7 @@ void quantize(uint dst_idx, uint src_idx)
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
if (gl_LocalInvocationIndex.x != 0) {
|
||||
return;
|
||||
|
||||
@@ -88,6 +88,83 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ1_S)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib32 = iqs / 32;
|
||||
const uint ib8 = iqs / 8;
|
||||
const int i8 = int(iqs % 8);
|
||||
const uint qh = data_a[a_offset + ib].qh[ib32];
|
||||
const uint qs = data_a[a_offset + ib].qs[ib8];
|
||||
const float dl = float(2 * bitfieldExtract(qh, 12, 3) + 1);
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const uint idxhi = bitfieldExtract(qh, 3 * int(ib8 & 3), 3);
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]);
|
||||
// Signed bitfield extract.
|
||||
const ivec2 gvec = ivec2(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
||||
);
|
||||
return dl * (vec2(gvec) + delta);
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib32 = iqs / 32;
|
||||
const uint ib8 = iqs / 8;
|
||||
const int i8 = int(iqs % 8);
|
||||
const uint qh = data_a[a_offset + ib].qh[ib32];
|
||||
const uint qs = data_a[a_offset + ib].qs[ib8];
|
||||
const float dl = 2 * bitfieldExtract(qh, 12, 3) + 1;
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
|
||||
// Signed bitfield extract.
|
||||
const ivec4 gvec = ivec4(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 2), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 3), 2)
|
||||
);
|
||||
return dl * (vec4(gvec) + delta);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ1_M)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib8 = iqs / 8;
|
||||
const uint ib16 = iqs / 16;
|
||||
const int i8 = int(iqs % 8);
|
||||
const uint sc = data_a[a_offset + ib].scales[iqs / 64];
|
||||
const uint qs = data_a[a_offset + ib].qs[ib8];
|
||||
const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));
|
||||
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
// Signed bitfield extract.
|
||||
const ivec2 gvec = ivec2(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
||||
);
|
||||
return dl * (vec2(gvec) + delta);
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib8 = iqs / 8;
|
||||
const uint ib16 = iqs / 16;
|
||||
const int i8 = int(iqs % 8);
|
||||
const uint sc = data_a[a_offset + ib].scales[iqs / 64];
|
||||
const uint qs = data_a[a_offset + ib].qs[ib8];
|
||||
const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));
|
||||
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
// Signed bitfield extract.
|
||||
const ivec4 gvec = ivec4(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 2), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 3), 2)
|
||||
);
|
||||
return dl * (vec4(gvec) + delta);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_XXS)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib32 = iqs / 32;
|
||||
@@ -357,7 +434,16 @@ vec2 get_dm(uint ib, uint a_offset) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
#if defined(DATA_A_IQ1_M)
|
||||
vec2 get_dm(uint ib, uint a_offset) {
|
||||
const uint16_t[4] scales = data_a[a_offset + ib].scales;
|
||||
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
||||
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
|
||||
return vec2(d, 0);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
vec2 get_dm(uint ib, uint a_offset) {
|
||||
return vec2(float(data_a[a_offset + ib].d), 0);
|
||||
}
|
||||
|
||||
@@ -301,6 +301,56 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2
|
||||
return ret;
|
||||
}
|
||||
|
||||
#if defined(DATA_A_IQ1_S)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S {
|
||||
block_iq1_s block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint ib32 = idx / 32;
|
||||
const uint ib8 = idx / 8;
|
||||
|
||||
const uint qh = bl.block.qh[ib32];
|
||||
const uint qs = bl.block.qs[ib8];
|
||||
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)];
|
||||
|
||||
float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta));
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ1_M)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M {
|
||||
block_iq1_m block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const u16vec4 scales = u16vec4(bl.block.scales[0], bl.block.scales[1], bl.block.scales[2], bl.block.scales[3]) >> 12;
|
||||
const float16_t d = uint16BitsToHalf(scales.x | (scales.y << 4) | (scales.z << 8) | (scales.w << 12));
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint ib8 = idx / 8;
|
||||
const uint ib16 = idx / 16;
|
||||
const int i8 = int(idx % 8);
|
||||
const uint sc = bl.block.scales[ib8 / 8];
|
||||
const uint qs = bl.block.qs[ib8];
|
||||
const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1));
|
||||
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const uint grid = iq1s_grid[qs | ((qh & 7) << 8)];
|
||||
|
||||
float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta));
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_XXS)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS {
|
||||
block_iq2_xxs block;
|
||||
@@ -512,6 +562,10 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
|
||||
#define dequantFuncA dequantFuncQ5_K
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
#define dequantFuncA dequantFuncQ6_K
|
||||
#elif defined(DATA_A_IQ1_S)
|
||||
#define dequantFuncA dequantFuncIQ1_S
|
||||
#elif defined(DATA_A_IQ1_M)
|
||||
#define dequantFuncA dequantFuncIQ1_M
|
||||
#elif defined(DATA_A_IQ2_XXS)
|
||||
#define dequantFuncA dequantFuncIQ2_XXS
|
||||
#elif defined(DATA_A_IQ2_XS)
|
||||
|
||||
42
ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp
Normal file
42
ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp
Normal file
@@ -0,0 +1,42 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
|
||||
#include "dequant_head.comp"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_iq1_m data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
// Each thread handles 1 subblock (32 values with 2 scales)
|
||||
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
if (ib >= p.nel / 256) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint ib32 = gl_LocalInvocationID.x % 8;
|
||||
const uint ib64 = ib32 / 2;
|
||||
const uint b_idx = 256 * ib + 32 * ib32;
|
||||
|
||||
const uint16_t[4] scales = data_a[ib].scales;
|
||||
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
||||
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
|
||||
|
||||
const uint sc = data_a[ib].scales[ib64];
|
||||
[[unroll]] for (int l = 0; l < 4; ++l) {
|
||||
const uint ib16 = 2 * ib32 + l / 2;
|
||||
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
|
||||
const uint qh = data_a[ib].qh[ib16] >> (4 * (l & 1));
|
||||
const uint qs = data_a[ib].qs[4 * ib32 + l];
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
[[unroll]] for (int j = 0; j < 8; ++j) {
|
||||
data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
|
||||
}
|
||||
}
|
||||
}
|
||||
35
ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp
Normal file
35
ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp
Normal file
@@ -0,0 +1,35 @@
|
||||
#version 450
|
||||
|
||||
#include "dequant_head.comp"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_iq1_s data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
// Each thread handles 1 subblock (32 values with 2 scales)
|
||||
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
if (ib >= p.nel / 256) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint ib32 = gl_LocalInvocationID.x % 8;
|
||||
const uint b_idx = 256 * ib + 32 * ib32;
|
||||
|
||||
uint qh = data_a[ib].qh[ib32];
|
||||
const float d = float(data_a[ib].d);
|
||||
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||
const uint qs = data_a[ib].qs[4 * ib32 + l];
|
||||
const uint hi = bitfieldExtract(qh, 3 * int(l), 3);
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (hi << 8)]);
|
||||
[[unroll]] for (int j = 0; j < 8; ++j) {
|
||||
data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -104,7 +104,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ void main() {
|
||||
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
|
||||
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
|
||||
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
|
||||
@@ -133,7 +133,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
void main() {
|
||||
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
|
||||
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
|
||||
82
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp
Normal file
82
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp
Normal file
@@ -0,0 +1,82 @@
|
||||
#version 450
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#include "mul_mat_vec_base.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
|
||||
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
|
||||
const uint y_idx = i * QUANT_K + 32 * ib32;
|
||||
|
||||
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint16_t[4] scales = data_a[ibi].scales;
|
||||
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
||||
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
|
||||
|
||||
const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1));
|
||||
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||
const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1));
|
||||
const uint qs = data_a[ibi].qs[4 * ib32 + l];
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1);
|
||||
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
|
||||
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
||||
[[unroll]] for (int k = 0; k < 4; ++k) {
|
||||
sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta,
|
||||
fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum));
|
||||
}
|
||||
temp[j][n] = fma(dl, sum, temp[j][n]);
|
||||
}
|
||||
}
|
||||
ibi += num_blocks_per_row;
|
||||
}
|
||||
}
|
||||
|
||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
uint a_offset, b_offset, d_offset;
|
||||
get_offsets(a_offset, b_offset, d_offset);
|
||||
|
||||
const uint num_blocks_per_row = p.ncols / QUANT_K;
|
||||
|
||||
// 8 threads are used to process each block
|
||||
const uint blocks_per_wg = gl_WorkGroupSize.x/8;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint itid = tid % 8; // 0...7
|
||||
const uint ix = tid / 8;
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||
temp[j][i] = FLOAT_TYPE(0);
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
|
||||
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
|
||||
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
// do NUM_ROWS at a time, unless there aren't enough remaining rows
|
||||
if (first_row + NUM_ROWS <= p.stride_d) {
|
||||
compute_outputs(first_row, NUM_ROWS);
|
||||
} else {
|
||||
if (first_row >= p.stride_d) {
|
||||
return;
|
||||
}
|
||||
compute_outputs(first_row, p.stride_d - first_row);
|
||||
}
|
||||
}
|
||||
79
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp
Normal file
79
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp
Normal file
@@ -0,0 +1,79 @@
|
||||
#version 450
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#include "mul_mat_vec_base.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
|
||||
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
|
||||
const uint y_idx = i * QUANT_K + 32 * ib32;
|
||||
|
||||
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const float d = float(data_a[ibi].d);
|
||||
const uint qh = data_a[ibi].qh[ib32];
|
||||
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
|
||||
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||
const uint qs = data_a[ibi].qs[4 * ib32 + l];
|
||||
const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3);
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]);
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
|
||||
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
||||
[[unroll]] for (int k = 0; k < 4; ++k) {
|
||||
sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta,
|
||||
fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum));
|
||||
}
|
||||
temp[j][n] = fma(dl, sum, temp[j][n]);
|
||||
}
|
||||
}
|
||||
ibi += num_blocks_per_row;
|
||||
}
|
||||
}
|
||||
|
||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
uint a_offset, b_offset, d_offset;
|
||||
get_offsets(a_offset, b_offset, d_offset);
|
||||
|
||||
const uint num_blocks_per_row = p.ncols / QUANT_K;
|
||||
|
||||
// 8 threads are used to process each block
|
||||
const uint blocks_per_wg = gl_WorkGroupSize.x/8;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint itid = tid % 8; // 0...7
|
||||
const uint ix = tid / 8;
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||
temp[j][i] = FLOAT_TYPE(0);
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
|
||||
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
|
||||
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
// do NUM_ROWS at a time, unless there aren't enough remaining rows
|
||||
if (first_row + NUM_ROWS <= p.stride_d) {
|
||||
compute_outputs(first_row, NUM_ROWS);
|
||||
} else {
|
||||
if (first_row >= p.stride_d) {
|
||||
return;
|
||||
}
|
||||
compute_outputs(first_row, p.stride_d - first_row);
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,9 @@
|
||||
#ifdef FLOAT16
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#endif
|
||||
#if defined(DATA_A_IQ1_M)
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
@@ -95,7 +98,7 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
@@ -437,6 +440,56 @@ void main() {
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
|
||||
#elif defined(DATA_A_IQ1_S)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
const uint ib8 = (idx % 128) / 4;
|
||||
const int i8 = 2 * int(idx % 4);
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qh = data_a[ib].qh[ib32];
|
||||
const uint qs = data_a[ib].qs[ib8];
|
||||
const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
|
||||
|
||||
const ivec2 gvec = ivec2(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
||||
);
|
||||
const vec2 v = dl * (vec2(gvec) + delta);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_IQ1_M)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib8 = (idx % 128) / 4;
|
||||
const uint ib16 = ib8 / 2;
|
||||
const int i8 = 2 * int(idx % 4);
|
||||
|
||||
const uint16_t[4] scales = data_a[ib].scales;
|
||||
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
||||
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
|
||||
const uint sc = scales[ib8 / 8];
|
||||
const uint qs = data_a[ib].qs[ib8];
|
||||
const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1));
|
||||
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
const ivec2 gvec = ivec2(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
||||
);
|
||||
const vec2 v = dl * (vec2(gvec) + delta);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_IQ2_XXS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
@@ -106,7 +106,7 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
|
||||
@@ -294,6 +294,187 @@ struct block_q6_K_packed16
|
||||
|
||||
// IQuants
|
||||
|
||||
#define QUANT_K_IQ1_S 256
|
||||
#define QUANT_R_IQ1_S 1
|
||||
|
||||
struct block_iq1_s {
|
||||
float16_t d;
|
||||
uint8_t qs[QUANT_K_IQ1_S/8];
|
||||
uint16_t qh[QUANT_K_IQ1_S/32];
|
||||
};
|
||||
|
||||
#define QUANT_K_IQ1_M 256
|
||||
#define QUANT_R_IQ1_M 1
|
||||
|
||||
struct block_iq1_m {
|
||||
uint8_t qs[QUANT_K_IQ1_M/8];
|
||||
uint8_t qh[QUANT_K_IQ1_M/16];
|
||||
uint16_t scales[QUANT_K_IQ1_M/64];
|
||||
};
|
||||
|
||||
#if defined(DATA_A_IQ1_S)
|
||||
#define QUANT_K QUANT_K_IQ1_S
|
||||
#define QUANT_R QUANT_R_IQ1_S
|
||||
#define A_TYPE block_iq1_s
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ1_M)
|
||||
#define QUANT_K QUANT_K_IQ1_M
|
||||
#define QUANT_R QUANT_R_IQ1_M
|
||||
#define A_TYPE block_iq1_m
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)
|
||||
#define IQ1S_DELTA 0.125f
|
||||
#define IQ1M_DELTA 0.125f
|
||||
|
||||
// Packed IQ1S grid where every 2 vec8 are encoded on 32 bits (2 bits per coordinate).
|
||||
const uint[1024] iq1s_grid_const = {
|
||||
0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01,
|
||||
0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4,
|
||||
0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41,
|
||||
0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f,
|
||||
0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 0xf33cf3d0, 0xf300f334,
|
||||
0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f,
|
||||
0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040,
|
||||
0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f,
|
||||
0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5,
|
||||
0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3,
|
||||
0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff,
|
||||
0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570,
|
||||
0xf540f575, 0xf55df55f, 0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f,
|
||||
0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf,
|
||||
0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f,
|
||||
0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07,
|
||||
0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc,
|
||||
0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374,
|
||||
0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0,
|
||||
0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001,
|
||||
0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043,
|
||||
0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc,
|
||||
0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117,
|
||||
0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f,
|
||||
0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 0xc743c74c, 0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5,
|
||||
0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474,
|
||||
0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d,
|
||||
0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd,
|
||||
0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50,
|
||||
0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10,
|
||||
0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30,
|
||||
0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1,
|
||||
0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c,
|
||||
0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074,
|
||||
0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134,
|
||||
0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 0xd7d5d7d7,
|
||||
0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3,
|
||||
0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450,
|
||||
0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577,
|
||||
0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c,
|
||||
0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5,
|
||||
0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c,
|
||||
0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00,
|
||||
0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300,
|
||||
0x331d331c, 0x33173310, 0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc,
|
||||
0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034,
|
||||
0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077,
|
||||
0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5,
|
||||
0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117,
|
||||
0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f,
|
||||
0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5,
|
||||
0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 0x34013400, 0x341f3404,
|
||||
0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1,
|
||||
0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd,
|
||||
0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71,
|
||||
0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7,
|
||||
0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00,
|
||||
0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44,
|
||||
0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00,
|
||||
0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0,
|
||||
0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 0x03000303,
|
||||
0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343,
|
||||
0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd,
|
||||
0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031,
|
||||
0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011,
|
||||
0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c,
|
||||
0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4,
|
||||
0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c,
|
||||
0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174,
|
||||
0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7,
|
||||
0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d,
|
||||
0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4,
|
||||
0x0433043c, 0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c,
|
||||
0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7,
|
||||
0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510,
|
||||
0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33,
|
||||
0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4,
|
||||
0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73,
|
||||
0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f,
|
||||
0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337,
|
||||
0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343,
|
||||
0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030,
|
||||
0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075,
|
||||
0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 0x11d711d1, 0x113f11d4,
|
||||
0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170,
|
||||
0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705,
|
||||
0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c,
|
||||
0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c,
|
||||
0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514,
|
||||
0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c,
|
||||
0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3,
|
||||
0x7cd07cc4, 0x7c337c3c, 0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70,
|
||||
0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03,
|
||||
0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c,
|
||||
0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c,
|
||||
0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074,
|
||||
0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104,
|
||||
0x7111711c, 0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7,
|
||||
0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757,
|
||||
0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c,
|
||||
0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c,
|
||||
0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4,
|
||||
0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc,
|
||||
0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03,
|
||||
0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc,
|
||||
0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54,
|
||||
0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f,
|
||||
0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf,
|
||||
0x40c040c3, 0x40c440c1, 0x40d040dc, 0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c,
|
||||
0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c,
|
||||
0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4,
|
||||
0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174,
|
||||
0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700,
|
||||
0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7,
|
||||
0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d,
|
||||
0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531,
|
||||
0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf,
|
||||
0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57,
|
||||
0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13,
|
||||
0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01,
|
||||
0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f,
|
||||
0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7,
|
||||
0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074,
|
||||
0x50475040, 0x51cc51f0, 0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107,
|
||||
0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd,
|
||||
0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0,
|
||||
0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7,
|
||||
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
|
||||
};
|
||||
|
||||
shared uint16_t iq1s_grid[2048];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
for (uint i = gl_LocalInvocationIndex.x; i < iq1s_grid_const.length(); i += wgsize.x) {
|
||||
u16vec2 g = unpack16(iq1s_grid_const[i]);
|
||||
iq1s_grid[2*i+0] = g.x;
|
||||
iq1s_grid[2*i+1] = g.y;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
#endif
|
||||
|
||||
#define QUANT_K_IQ2_XXS 256
|
||||
#define QUANT_R_IQ2_XXS 1
|
||||
|
||||
@@ -380,6 +561,7 @@ const uvec2[256] iq2xxs_grid_const = {
|
||||
|
||||
shared uvec2 iq2xxs_grid[256];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
@@ -547,6 +729,7 @@ const uvec2 iq2xs_grid_const[512] = {
|
||||
|
||||
shared uvec2 iq2xs_grid[512];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
@@ -836,6 +1019,7 @@ const uvec2 iq2s_grid_const[1024] = {
|
||||
|
||||
shared uvec2 iq2s_grid[1024];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
@@ -904,6 +1088,7 @@ const uint32_t iq3xxs_grid_const[256] = {
|
||||
|
||||
shared uint32_t iq3xxs_grid[256];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
@@ -1011,6 +1196,7 @@ const uint32_t iq3s_grid_const[512] = {
|
||||
|
||||
shared uint32_t iq3s_grid[512];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
@@ -1073,6 +1259,7 @@ const int8_t kvalues_iq4nl_const[16] = {
|
||||
|
||||
shared FLOAT_TYPE kvalues_iq4nl[16];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
|
||||
@@ -55,6 +55,8 @@ const std::vector<std::string> type_names = {
|
||||
"q4_k",
|
||||
"q5_k",
|
||||
"q6_k",
|
||||
"iq1_s",
|
||||
"iq1_m",
|
||||
"iq2_xxs",
|
||||
"iq2_xs",
|
||||
"iq2_s",
|
||||
@@ -182,6 +184,13 @@ std::string to_uppercase(const std::string& input) {
|
||||
return result;
|
||||
}
|
||||
|
||||
bool string_starts_with(const std::string& str, const std::string& prefix) {
|
||||
if (prefix.size() > str.size()) {
|
||||
return false;
|
||||
}
|
||||
return std::equal(prefix.begin(), prefix.end(), str.begin());
|
||||
}
|
||||
|
||||
bool string_ends_with(const std::string& str, const std::string& suffix) {
|
||||
if (suffix.size() > str.size()) {
|
||||
return false;
|
||||
@@ -387,7 +396,7 @@ void process_shaders() {
|
||||
for (const auto& tname : type_names) {
|
||||
// mul mat vec
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
|
||||
std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
|
||||
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
|
||||
|
||||
@@ -1379,7 +1379,7 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso
|
||||
(t0->nb[3] == t1->nb[3]);
|
||||
}
|
||||
|
||||
// check if t1 can be represented as a repeatition of t0
|
||||
// check if t1 can be represented as a repetition of t0
|
||||
bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
|
||||
@@ -1172,6 +1172,9 @@ extern "C" {
|
||||
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
|
||||
|
||||
/// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n);
|
||||
|
||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
|
||||
22
models/templates/README.md
Normal file
22
models/templates/README.md
Normal file
@@ -0,0 +1,22 @@
|
||||
These templates can be updated with the following commands:
|
||||
|
||||
```bash
|
||||
./scripts/get_chat_template.py CohereForAI/c4ai-command-r-plus tool_use > models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja
|
||||
./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 default > models/templates/CohereForAI-c4ai-command-r7b-12-2024-default.jinja
|
||||
./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 rag > models/templates/CohereForAI-c4ai-command-r7b-12-2024-rag.jinja
|
||||
./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use > models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja
|
||||
./scripts/get_chat_template.py deepseek-ai/DeepSeek-R1-Distill-Llama-8B > models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja
|
||||
./scripts/get_chat_template.py deepseek-ai/DeepSeek-R1-Distill-Qwen-32B > models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja
|
||||
./scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 > models/templates/fireworks-ai-llama-3-firefunction-v2.jinja
|
||||
./scripts/get_chat_template.py google/gemma-2-2b-it > models/templates/google-gemma-2-2b-it.jinja
|
||||
./scripts/get_chat_template.py meetkai/functionary-medium-v3. > models/templates/meetkai-functionary-medium-v3.jinja
|
||||
./scripts/get_chat_template.py meetkai/functionary-medium-v3.2 > models/templates/meetkai-functionary-medium-v3.2.jinja
|
||||
./scripts/get_chat_template.py meta-llama/Llama-3.1-8B-Instruct > models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja
|
||||
./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct > models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja
|
||||
./scripts/get_chat_template.py meta-llama/Llama-3.3-70B-Instruct > models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja
|
||||
./scripts/get_chat_template.py microsoft/Phi-3.5-mini-instruct > models/templates/microsoft-Phi-3.5-mini-instruct.jinja
|
||||
./scripts/get_chat_template.py mistralai/Mistral-Nemo-Instruct-2407 > models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja
|
||||
./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
|
||||
./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
|
||||
./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
|
||||
```
|
||||
@@ -1 +1 @@
|
||||
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}
|
||||
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\n'}}{% endif %}
|
||||
@@ -1,56 +1 @@
|
||||
{% if not add_generation_prompt is defined %}
|
||||
{% set add_generation_prompt = false %}
|
||||
{% endif %}
|
||||
{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'system' %}
|
||||
{% set ns.system_prompt = message['content'] %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{bos_token}}
|
||||
{{ns.system_prompt}}
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'user' %}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{{'<|User|>' + message['content']}}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'assistant' and message['content'] is none %}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- for tool in message['tool_calls']%}
|
||||
{%- if not ns.is_first %}
|
||||
{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{%- set ns.is_first = true -%}
|
||||
{%- else %}
|
||||
{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'assistant' and message['content'] is not none %}
|
||||
{%- if ns.is_tool %}
|
||||
{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- else %}
|
||||
{% set content = message['content'] %}
|
||||
{% if '</think>' in content %}
|
||||
{% set content = content.split('</think>')[-1] %}
|
||||
{% endif %}
|
||||
{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'tool' %}
|
||||
{%- set ns.is_tool = true -%}
|
||||
{%- if ns.is_output_first %}
|
||||
{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||
{%- set ns.is_output_first = false %}
|
||||
{%- else %}
|
||||
{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor -%}
|
||||
{% if ns.is_tool %}
|
||||
{{'<|tool▁outputs▁end|>'}}
|
||||
{% endif %}
|
||||
{% if add_generation_prompt and not ns.is_tool %}
|
||||
{{'<|Assistant|>'}}
|
||||
{% endif %}
|
||||
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\n'}}{% endif %}
|
||||
76
models/templates/llama-cpp-deepseek-r1.jinja
Normal file
76
models/templates/llama-cpp-deepseek-r1.jinja
Normal file
@@ -0,0 +1,76 @@
|
||||
{%- if not add_generation_prompt is defined -%}
|
||||
{%- set add_generation_prompt = false -%}
|
||||
{%- endif -%}
|
||||
{%- set ns = namespace(is_first=false, is_tool_outputs=false, is_output_first=true, system_prompt='') -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
{%- set ns.system_prompt = message['content'] -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{bos_token}}
|
||||
{%- if tools %}
|
||||
You can call any of the following function tools to satisfy the user's requests: {{tools | map(attribute='function') | tojson(indent=2)}}
|
||||
|
||||
Example function tool call syntax:
|
||||
|
||||
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>example_function_name
|
||||
```json
|
||||
{
|
||||
"arg1": "some_value"
|
||||
...
|
||||
}
|
||||
```
|
||||
<|tool▁call▁end|><|tool▁calls▁end|>
|
||||
|
||||
{% endif -%}
|
||||
{{ns.system_prompt}}
|
||||
{%- macro flush_tool_outputs() -%}
|
||||
{%- if ns.is_tool_outputs -%}
|
||||
{{- '<|tool▁outputs▁end|><|end▁of▁sentence|>' -}}
|
||||
{%- set ns.is_tool_outputs = false -%}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{{- flush_tool_outputs() -}}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] != 'tool' -%}
|
||||
{{- flush_tool_outputs() -}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{- '<|User|>' + message['content'] + '<|end▁of▁sentence|>' -}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and message['content'] is none -%}
|
||||
{{- '<|Assistant|><|tool▁calls▁begin|>' -}}
|
||||
{%- set ns.is_first = true -%}
|
||||
{%- for tc in message['tool_calls'] -%}
|
||||
{%- if ns.is_first -%}
|
||||
{%- set ns.is_first = false -%}
|
||||
{%- else -%}
|
||||
{{- '\n' -}}
|
||||
{%- endif -%}
|
||||
{%- set tool_name = tc['function']['name'] -%}
|
||||
{%- set tool_args = tc['function']['arguments'] -%}
|
||||
{{- '<|tool▁call▁begin|>' + tc['type'] + '<|tool▁sep|>' + tool_name + '\n' + '```json' + '\n' + tool_args + '\n' + '```' + '<|tool▁call▁end|>' -}}
|
||||
{%- endfor -%}
|
||||
{{- '<|tool▁calls▁end|><|end▁of▁sentence|>' -}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and message['content'] is not none -%}
|
||||
{{- flush_tool_outputs() -}}
|
||||
{%- set content = message['content'] -%}
|
||||
{%- if '</think>' in content -%}
|
||||
{%- set content = content.split('</think>')[-1] -%}
|
||||
{%- endif -%}
|
||||
{{- '<|Assistant|>' + content + '<|end▁of▁sentence|>' -}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'tool' -%}
|
||||
{%- set ns.is_tool_outputs = true -%}
|
||||
{%- if ns.is_output_first -%}
|
||||
{{- '<|tool▁outputs▁begin|>' -}}
|
||||
{%- set ns.is_output_first = false -%}
|
||||
{%- endif -%}
|
||||
{{- '\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- flush_tool_outputs() -}}
|
||||
{%- if add_generation_prompt and not ns.is_tool_outputs -%}
|
||||
{{- '<|Assistant|><think>\n' -}}
|
||||
{%- endif -%}
|
||||
5
scripts/get_chat_template.py
Normal file → Executable file
5
scripts/get_chat_template.py
Normal file → Executable file
@@ -7,9 +7,8 @@
|
||||
./scripts/get_chat_template.py model_id [variant]
|
||||
|
||||
Examples:
|
||||
./scripts/get_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
|
||||
./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
|
||||
./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct
|
||||
./scripts/get_chat_template.py CohereForAI/c4ai-command-r-plus tool_use
|
||||
./scripts/get_chat_template.py microsoft/Phi-3.5-mini-instruct
|
||||
'''
|
||||
|
||||
import json
|
||||
|
||||
@@ -1 +1 @@
|
||||
08b538031f7f944e84f472483ef5d26bf5190ead
|
||||
98a61a0d0b43cba06c3ac1c603813639552a0701
|
||||
|
||||
@@ -1186,7 +1186,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
return;
|
||||
}
|
||||
}
|
||||
LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str());
|
||||
LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,7 +116,7 @@ struct llama_grammar {
|
||||
llama_partial_utf8 partial_utf8;
|
||||
|
||||
// lazy grammars wait for trigger words or tokens before constraining the sampling.
|
||||
// we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
|
||||
// we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
|
||||
// (useful e.g. for tool_choice=required)
|
||||
bool lazy = false;
|
||||
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
|
||||
|
||||
@@ -6,13 +6,13 @@
|
||||
#include <vector>
|
||||
|
||||
#ifdef __GNUC__
|
||||
#ifdef __MINGW32__
|
||||
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||
# if defined(__MINGW32__) && !defined(__clang__)
|
||||
# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||
# else
|
||||
# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||
# endif
|
||||
#else
|
||||
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||
#endif
|
||||
#else
|
||||
#define LLAMA_ATTRIBUTE_FORMAT(...)
|
||||
# define LLAMA_ATTRIBUTE_FORMAT(...)
|
||||
#endif
|
||||
|
||||
//
|
||||
|
||||
@@ -37,7 +37,7 @@ struct llama_kv_cache {
|
||||
bool can_shift = false;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_internal also uses it, so it
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
|
||||
@@ -1275,7 +1275,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
|
||||
const bool use_mmap_buffer = true;
|
||||
|
||||
LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, use_mmap_buffer ? "true" : "false");
|
||||
LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false");
|
||||
|
||||
// build a list of buffer types for the CPU and GPU devices
|
||||
pimpl->cpu_buft_list = make_cpu_buft_list(devices);
|
||||
|
||||
@@ -1698,6 +1698,73 @@ struct llama_sampler * llama_sampler_init_penalties(
|
||||
);
|
||||
}
|
||||
|
||||
// top-n-sigma
|
||||
|
||||
struct llama_sampler_top_n_sigma {
|
||||
const float n;
|
||||
};
|
||||
|
||||
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
|
||||
return "top-n-sigma";
|
||||
}
|
||||
|
||||
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
|
||||
|
||||
// find max logit and calculate mean
|
||||
float max = cur_p->data[0].logit;
|
||||
float logits_sum = 0;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if (cur_p->data[i].logit > max) {
|
||||
max = cur_p->data[i].logit;
|
||||
}
|
||||
logits_sum += cur_p->data[i].logit;
|
||||
}
|
||||
float mean = logits_sum/cur_p->size;
|
||||
|
||||
// calculate standard deviation
|
||||
float acc = 0;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
acc += pow(cur_p->data[i].logit - mean, 2);
|
||||
}
|
||||
float std = sqrt(acc/cur_p->size);
|
||||
|
||||
//apply mask
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if (cur_p->data[i].logit < max - (ctx->n * std)) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
|
||||
return llama_sampler_init_top_n_sigma(ctx->n);
|
||||
}
|
||||
|
||||
static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_top_n_sigma *) smpl->ctx;
|
||||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
|
||||
/* .name = */ llama_sampler_top_n_sigma_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_top_n_sigma_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_top_n_sigma_clone,
|
||||
/* .free = */ llama_sampler_top_n_sigma_free,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
|
||||
return llama_sampler_init(
|
||||
/* .iface = */ &llama_sampler_top_n_sigma_i,
|
||||
/* .ctx = */ new llama_sampler_top_n_sigma {
|
||||
/* .n = */ n,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// DRY
|
||||
|
||||
struct llama_sampler_dry {
|
||||
|
||||
@@ -9430,7 +9430,6 @@ static struct llama_model * llama_model_load_from_file_impl(
|
||||
struct llama_model_params params) {
|
||||
ggml_time_init();
|
||||
|
||||
|
||||
unsigned cur_percentage = 0;
|
||||
if (params.progress_callback == NULL) {
|
||||
params.progress_callback_user_data = &cur_percentage;
|
||||
|
||||
@@ -24,7 +24,10 @@ static common_chat_msg msg_from_json(const json & message) {
|
||||
ret.content = message.at("content");
|
||||
}
|
||||
if (message.contains("tool_plan")) {
|
||||
ret.tool_plan = message.at("tool_plan");
|
||||
ret.reasoning_content = message.at("tool_plan");
|
||||
}
|
||||
if (message.contains("reasoning_content")) {
|
||||
ret.reasoning_content = message.at("reasoning_content");
|
||||
}
|
||||
auto has_tool_calls = message.contains("tool_calls");
|
||||
if (has_tool_calls) {
|
||||
@@ -105,6 +108,7 @@ static std::string dump(const json & j) {
|
||||
static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
|
||||
assert_equals(expected.role, actual.role);
|
||||
assert_equals(expected.content, actual.content);
|
||||
assert_equals(expected.reasoning_content, actual.reasoning_content);
|
||||
assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
|
||||
for (size_t i = 0; i < expected.tool_calls.size(); i++) {
|
||||
const auto & expected_tool_call = expected.tool_calls[i];
|
||||
@@ -176,13 +180,15 @@ struct delta_data {
|
||||
|
||||
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
|
||||
const json & user_message, const json & delta_message, const json & tools,
|
||||
const json & tool_choice) {
|
||||
const json & tool_choice,
|
||||
bool think = false) {
|
||||
common_chat_inputs inputs;
|
||||
inputs.parallel_tool_calls = true;
|
||||
inputs.messages = json::array();
|
||||
inputs.messages.push_back(user_message);
|
||||
inputs.tools = tools;
|
||||
inputs.tool_choice = tool_choice;
|
||||
inputs.extract_reasoning = think;
|
||||
auto params_prefix = common_chat_params_init(tmpl, inputs);
|
||||
|
||||
inputs.messages.push_back(delta_message);
|
||||
@@ -192,17 +198,24 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
|
||||
std::string prefix = params_prefix.prompt;
|
||||
std::string full = params_full.prompt;
|
||||
|
||||
// Check full starts with prefix
|
||||
if (full.find(prefix) != 0) {
|
||||
fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str());
|
||||
throw std::runtime_error("Full message does not start with prefix");
|
||||
}
|
||||
|
||||
if (full == prefix) {
|
||||
throw std::runtime_error("Full message is the same as the prefix");
|
||||
}
|
||||
|
||||
auto delta = full.substr(prefix.size());
|
||||
size_t common_prefix_length = 0;
|
||||
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
|
||||
if (prefix[i] != full[i]) {
|
||||
break;
|
||||
}
|
||||
if (prefix[i] == '<') {
|
||||
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
|
||||
// but it removes thinking tags for past messages.
|
||||
// The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
|
||||
continue;
|
||||
}
|
||||
common_prefix_length = i + 1;
|
||||
}
|
||||
auto delta = full.substr(common_prefix_length);
|
||||
|
||||
// Strip end tokens
|
||||
for (const auto & end_token : end_tokens) {
|
||||
@@ -223,7 +236,9 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
|
||||
*/
|
||||
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
|
||||
const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
|
||||
bool expect_grammar_triggered = true) {
|
||||
bool expect_grammar_triggered = true,
|
||||
bool test_grammar_if_triggered = true,
|
||||
bool think = false) {
|
||||
common_chat_msg expected_msg = msg_from_json(test_message);
|
||||
|
||||
auto user_message = json{
|
||||
@@ -232,7 +247,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||
};
|
||||
|
||||
for (const auto & tool_choice : json({ "auto", "required" })) {
|
||||
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
|
||||
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice, think);
|
||||
if (!expected_delta.empty()) {
|
||||
assert_equals(expected_delta, data.delta);
|
||||
}
|
||||
@@ -274,7 +289,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||
assert_equals(expect_grammar_triggered, grammar_triggered);
|
||||
}
|
||||
|
||||
if (grammar_triggered && !match_string(constrained, grammar.get())) {
|
||||
if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
|
||||
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
|
||||
"\n\nGrammar: " + data.params.grammar);
|
||||
}
|
||||
@@ -283,16 +298,33 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||
}
|
||||
|
||||
static void test_template_output_parsers() {
|
||||
json text_message {
|
||||
json message_user {
|
||||
{ "role", "user" },
|
||||
{ "content", "Hey there!" },
|
||||
};
|
||||
json message_assist {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "Hello, world!\nWhat's up?" },
|
||||
};
|
||||
json message_assist_thoughts_unparsed_think {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "<think>I'm thinking</think>Hello, world!\nWhat's up?" },
|
||||
};
|
||||
json message_assist_thoughts_unparsed_r7b {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?" },
|
||||
};
|
||||
json message_assist_thoughts {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "Hello, world!\nWhat's up?" },
|
||||
{ "reasoning_content", "I'm thinking" },
|
||||
};
|
||||
json tool_calls = json::array({{
|
||||
{ "type", "function" },
|
||||
{ "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
|
||||
}});
|
||||
|
||||
json tool_call_message {
|
||||
json message_assist_call {
|
||||
{ "role", "assistant"},
|
||||
{ "content", {}},
|
||||
{ "tool_calls", {
|
||||
@@ -305,7 +337,34 @@ static void test_template_output_parsers() {
|
||||
},
|
||||
}},
|
||||
};
|
||||
json tool_call_message_with_id {
|
||||
json message_assist_call_thoughts = {
|
||||
{ "role", "assistant" },
|
||||
{ "content", nullptr },
|
||||
{ "reasoning_content", "I'm\nthinking" },
|
||||
{ "tool_calls", {
|
||||
{
|
||||
{ "type", "function" },
|
||||
{ "function", {
|
||||
{ "name", "special_function" },
|
||||
{ "arguments", "{\"arg1\": 1}" },
|
||||
}},
|
||||
},
|
||||
}},
|
||||
};
|
||||
json message_assist_call_thoughts_unparsed = {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "<think>I'm\nthinking</think>" },
|
||||
{ "tool_calls", {
|
||||
{
|
||||
{ "type", "function" },
|
||||
{ "function", {
|
||||
{ "name", "special_function" },
|
||||
{ "arguments", "{\"arg1\": 1}" },
|
||||
}},
|
||||
},
|
||||
}},
|
||||
};
|
||||
json message_assist_call_id {
|
||||
{ "role", "assistant"},
|
||||
{ "content", {}},
|
||||
{ "tool_calls", {
|
||||
@@ -322,10 +381,9 @@ static void test_template_output_parsers() {
|
||||
{ "content", {} },
|
||||
{ "tool_calls", tool_calls }
|
||||
};
|
||||
json tool_call_plan_message_with_idx {
|
||||
json message_assist_call_idx {
|
||||
{ "role", "assistant"},
|
||||
{ "content", {}},
|
||||
{ "tool_plan", "I'm not so sure"},
|
||||
{ "tool_calls", {
|
||||
{
|
||||
{ "type", "function" },
|
||||
@@ -341,8 +399,10 @@ static void test_template_output_parsers() {
|
||||
{ "content", {} },
|
||||
{ "tool_calls", tool_calls }
|
||||
};
|
||||
json message_assist_call_tool_plan_idx = message_assist_call_idx;
|
||||
message_assist_call_tool_plan_idx["tool_plan"] = "I'm thinking";
|
||||
|
||||
auto python_tool_call_message = json{
|
||||
auto python_message_assist_call = json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", {} },
|
||||
{ "tool_calls", json{ {
|
||||
@@ -357,7 +417,7 @@ static void test_template_output_parsers() {
|
||||
} },
|
||||
} } }
|
||||
};
|
||||
auto code_interpreter_tool_call_message = json{
|
||||
auto code_interpreter_message_assist_call = json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", {} },
|
||||
{ "tool_calls", json{ {
|
||||
@@ -374,17 +434,27 @@ static void test_template_output_parsers() {
|
||||
};
|
||||
|
||||
common_chat_inputs inputs_no_tools;
|
||||
inputs_no_tools.messages = {
|
||||
{ { "role", "user" }, { "content", "Hey\nThere" } }
|
||||
};
|
||||
inputs_no_tools.messages = json::array({message_user});
|
||||
inputs_no_tools.extract_reasoning = false;
|
||||
|
||||
common_chat_inputs inputs_tools = inputs_no_tools;
|
||||
inputs_tools.tools = json::array();
|
||||
inputs_tools.tools.push_back(special_function_tool);
|
||||
common_chat_inputs inputs_no_tools_think;
|
||||
inputs_no_tools_think.messages = json::array({message_user});
|
||||
inputs_no_tools_think.extract_reasoning = true;
|
||||
|
||||
common_chat_inputs inputs_tools_builtin = inputs_no_tools;
|
||||
inputs_tools_builtin.tools = json::array();
|
||||
inputs_tools_builtin.tools.push_back(python_tool);
|
||||
common_chat_inputs inputs_tools;
|
||||
inputs_tools.messages = json::array({message_user});
|
||||
inputs_tools.tools = json::array({special_function_tool});
|
||||
inputs_tools.extract_reasoning = false;
|
||||
|
||||
common_chat_inputs inputs_tools_think;
|
||||
inputs_tools_think.messages = json::array({message_user});
|
||||
inputs_tools_think.tools = json::array({special_function_tool});
|
||||
inputs_tools_think.extract_reasoning = true;
|
||||
|
||||
common_chat_inputs inputs_tools_builtin;
|
||||
inputs_tools_builtin.messages = json::array({message_user});
|
||||
inputs_tools_builtin.tools = json::array({python_tool});
|
||||
inputs_tools_builtin.extract_reasoning = false;
|
||||
|
||||
{
|
||||
// Not supported yet
|
||||
@@ -395,15 +465,53 @@ static void test_template_output_parsers() {
|
||||
const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
|
||||
|
||||
test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools,
|
||||
"<|START_THINKING|>I'm not so sure<|END_THINKING|>"
|
||||
assert_msg_equals(msg_from_json(message_assist),
|
||||
common_chat_parse(
|
||||
"Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
assert_msg_equals(msg_from_json(message_assist),
|
||||
common_chat_parse(
|
||||
"Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
assert_msg_equals(msg_from_json(message_assist),
|
||||
common_chat_parse(
|
||||
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
|
||||
common_chat_parse(
|
||||
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
||||
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
|
||||
common_chat_parse(
|
||||
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
||||
"Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts),
|
||||
common_chat_parse(
|
||||
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
||||
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING));
|
||||
|
||||
test_template(tmpl, end_tokens, message_assist_call_idx, tools,
|
||||
"<|START_THINKING|><|END_THINKING|>"
|
||||
"<|START_ACTION|>[\n"
|
||||
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
|
||||
"]<|END_ACTION|>");
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist_call_tool_plan_idx, tools,
|
||||
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
||||
"<|START_ACTION|>[\n"
|
||||
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
|
||||
"]<|END_ACTION|>",
|
||||
/* expect_grammar_triggered= */ true,
|
||||
/* test_grammar_if_triggered= */ true,
|
||||
/* think= */ true);
|
||||
test_template(tmpl, end_tokens, message_assist, tools,
|
||||
"<|START_RESPONSE|>Hello, world!\n"
|
||||
"What's up?<|END_RESPONSE|>",
|
||||
/* expect_grammar_triggered= */ false);
|
||||
@@ -423,12 +531,12 @@ static void test_template_output_parsers() {
|
||||
|
||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||
|
||||
assert_msg_equals(msg_from_json(text_message),
|
||||
assert_msg_equals(msg_from_json(message_assist),
|
||||
common_chat_parse("{\n"
|
||||
" \"response\": \"Hello, world!\\nWhat's up?\"\n"
|
||||
"}",
|
||||
common_chat_params_init(tmpl, inputs_tools).format));
|
||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
test_template(tmpl, end_tokens, message_assist_call_id, tools,
|
||||
"{\n"
|
||||
" \"tool_calls\": [\n"
|
||||
" {\n"
|
||||
@@ -448,9 +556,9 @@ static void test_template_output_parsers() {
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(
|
||||
tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
tmpl, end_tokens, message_assist_call_id, tools,
|
||||
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
|
||||
}
|
||||
{
|
||||
@@ -473,12 +581,12 @@ static void test_template_output_parsers() {
|
||||
inputs_tools)
|
||||
.format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool_call>");
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, python_message_assist_call, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
|
||||
"</tool_call>");
|
||||
@@ -498,12 +606,12 @@ static void test_template_output_parsers() {
|
||||
inputs_tools_builtin)
|
||||
.format);
|
||||
|
||||
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||
// test_template(tmpl, end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, code_interpreter_message_assist_call, llama_3_1_tools,
|
||||
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, python_message_assist_call, tools,
|
||||
"<|python_tag|>python.call(code=\"print('hey')\")");
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
@@ -513,8 +621,8 @@ static void test_template_output_parsers() {
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
@@ -525,8 +633,8 @@ static void test_template_output_parsers() {
|
||||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
"<function=special_function>{\"arg1\": 1}</function>");
|
||||
}
|
||||
{
|
||||
@@ -537,12 +645,12 @@ static void test_template_output_parsers() {
|
||||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, {},
|
||||
test_template(tmpl, end_tokens, message_assist, {},
|
||||
"all\n"
|
||||
"Hello, world!\n"
|
||||
"What's up?",
|
||||
/* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
"special_function\n"
|
||||
"{\"arg1\": 1}");
|
||||
}
|
||||
@@ -553,23 +661,79 @@ static void test_template_output_parsers() {
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
||||
}
|
||||
{
|
||||
// Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt.
|
||||
const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
|
||||
"<s>", "</s>");
|
||||
std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
"```json\n"
|
||||
"{\"arg1\": 1}\n"
|
||||
"```<|tool▁call▁end|>");
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
|
||||
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts),
|
||||
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts),
|
||||
// Latest template update (ast of 20250209) adds a trailing <think>\n if add_generation_prompt is true.
|
||||
common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
|
||||
// test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
// "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
// "```json\n"
|
||||
// "{\"arg1\": 1}\n"
|
||||
// // Look what's not here: <|tool▁calls▁end|> (also missing the <|end▁of▁sentence|>, but that is removed lazily by the test's delta logic)
|
||||
// "```<|tool▁call▁end|>",
|
||||
// /* expect_grammar_triggered= */ true,
|
||||
// /* test_grammar_if_triggered= */ false);
|
||||
}
|
||||
{
|
||||
// Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all.
|
||||
const common_chat_template tmpl(read_file("models/templates/llama-cpp-deepseek-r1.jinja"),
|
||||
"<s>", "</s>");
|
||||
std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
|
||||
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
|
||||
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts),
|
||||
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
|
||||
|
||||
assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed),
|
||||
common_chat_parse(
|
||||
"<think>I'm\nthinking</think>\n\n"
|
||||
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
"```json\n"
|
||||
"{\"arg1\": 1}\n"
|
||||
"```<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
assert_msg_equals(msg_from_json(message_assist_call_thoughts),
|
||||
common_chat_parse(
|
||||
"<think>I'm\nthinking</think>\n\n"
|
||||
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
"```json\n"
|
||||
"{\"arg1\": 1}\n"
|
||||
"```<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
"```json\n"
|
||||
"{\"arg1\": 1}\n"
|
||||
"```<|tool▁call▁end|><|tool▁calls▁end|>");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -586,16 +750,20 @@ int main(int argc, char ** argv) {
|
||||
std::cout << "|----------|--------|\n";
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string path = argv[i];
|
||||
if (path.rfind(".jinja") != path.size() - 6) {
|
||||
std::cerr << "Skipping non-jinja file: " << path << std::endl;
|
||||
continue;
|
||||
try {
|
||||
std::string path = argv[i];
|
||||
if (path.rfind(".jinja") != path.size() - 6) {
|
||||
std::cerr << "Skipping non-jinja file: " << path << std::endl;
|
||||
continue;
|
||||
}
|
||||
common_chat_template tmpl(read_file(path), "", "");
|
||||
auto parts = string_split(path, "/");
|
||||
auto name = parts[parts.size() - 1];
|
||||
auto format = common_chat_format_name(common_chat_params_init(tmpl, inputs).format);
|
||||
std::cout << "| " << name << " | " << format << " |\n";
|
||||
} catch (const std::exception & e) {
|
||||
std::cerr << "Failed to process " << argv[i] << ": " << e.what() << std::endl;
|
||||
}
|
||||
common_chat_template tmpl(read_file(path), "", "");
|
||||
auto parts = string_split(path, "/");
|
||||
auto name = parts[parts.size() - 1];
|
||||
std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format)
|
||||
<< " |\n";
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
|
||||
@@ -697,8 +697,8 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
|
||||
|
||||
#ifdef _WIN32
|
||||
if (!file) {
|
||||
printf("%s: failed to create tmpfile(), needs elevated privileges on Windows");
|
||||
printf("%s: skipping tests");
|
||||
printf("failed to create tmpfile(), needs elevated privileges on Windows");
|
||||
printf("skipping tests");
|
||||
continue;
|
||||
}
|
||||
#else
|
||||
@@ -1086,8 +1086,8 @@ static std::pair<int, int> test_roundtrip(ggml_backend_dev_t dev, const unsigned
|
||||
|
||||
#ifdef _WIN32
|
||||
if (!file) {
|
||||
printf("%s: failed to create tmpfile(), needs elevated privileges on Windows");
|
||||
printf("%s: skipping tests");
|
||||
printf("failed to create tmpfile(), needs elevated privileges on Windows");
|
||||
printf("skipping tests");
|
||||
return std::make_pair(0, 0);
|
||||
}
|
||||
#else
|
||||
|
||||
@@ -181,6 +181,17 @@ static void test_dry(
|
||||
tester.check();
|
||||
}
|
||||
|
||||
static void test_top_n_sigma(const std::vector<float> & probs, const std::vector<float> & probs_expected, int n) {
|
||||
sampler_tester tester(probs, probs_expected);
|
||||
|
||||
DUMP(&tester.cur_p);
|
||||
tester.apply(llama_sampler_init_top_n_sigma(n));
|
||||
tester.apply(llama_sampler_init_dist (0));
|
||||
DUMP(&tester.cur_p);
|
||||
|
||||
tester.check();
|
||||
}
|
||||
|
||||
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
|
||||
) {
|
||||
sampler_tester tester(n_vocab);
|
||||
@@ -348,6 +359,10 @@ int main(void) {
|
||||
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
|
||||
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
|
||||
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f);
|
||||
|
||||
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
|
||||
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
|
||||
test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
|
||||
|
||||
Reference in New Issue
Block a user