mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-23 16:37:33 +03:00
Compare commits
1 Commits
b4720
...
revert-118
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f30aca84b2 |
@@ -1,6 +1,6 @@
|
||||
ARG UBUNTU_VERSION=22.04
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG CUDA_VERSION=12.4.0
|
||||
ARG CUDA_VERSION=12.6.0
|
||||
# Target the CUDA build image
|
||||
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
ARG UBUNTU_VERSION=22.04
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG MUSA_VERSION=rc3.1.1
|
||||
ARG MUSA_VERSION=rc3.1.0
|
||||
# Target the MUSA build image
|
||||
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
|
||||
32
.github/workflows/build.yml
vendored
32
.github/workflows/build.yml
vendored
@@ -401,35 +401,7 @@ jobs:
|
||||
run: |
|
||||
cd build
|
||||
# This is using llvmpipe and runs slower than other backends
|
||||
ctest -L main --verbose --timeout 2700
|
||||
|
||||
- name: Determine tag name
|
||||
id: tag
|
||||
shell: bash
|
||||
run: |
|
||||
BUILD_NUMBER="$(git rev-list --count HEAD)"
|
||||
SHORT_HASH="$(git rev-parse --short=7 HEAD)"
|
||||
if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
|
||||
echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT
|
||||
else
|
||||
SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
|
||||
echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp
|
||||
zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip ./build/bin/*
|
||||
|
||||
- name: Upload artifacts
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip
|
||||
name: llama-bin-ubuntu-vulkan-x64.zip
|
||||
ctest -L main --verbose --timeout 1800
|
||||
|
||||
ubuntu-22-cmake-hip:
|
||||
runs-on: ubuntu-22.04
|
||||
@@ -471,7 +443,7 @@ jobs:
|
||||
|
||||
ubuntu-22-cmake-musa:
|
||||
runs-on: ubuntu-22.04
|
||||
container: mthreads/musa:rc3.1.1-devel-ubuntu22.04
|
||||
container: mthreads/musa:rc3.1.0-devel-ubuntu22.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
15
README.md
15
README.md
@@ -519,18 +519,5 @@ If your issue is with model generation quality, then please at least scan the fo
|
||||
- [Aligning language models to follow instructions](https://openai.com/research/instruction-following)
|
||||
- [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155)
|
||||
|
||||
## Completions
|
||||
Command-line completion is available for some environments.
|
||||
#### References
|
||||
|
||||
#### Bash Completion
|
||||
```bash
|
||||
$ build/bin/llama-cli --completion-bash > ~/.llama-completion.bash
|
||||
$ source ~/.llama-completion.bash
|
||||
```
|
||||
Optionally this can be added to your `.bashrc` or `.bash_profile` to load it
|
||||
automatically. For example:
|
||||
```console
|
||||
$ echo "source ~/.llama-completion.bash" >> ~/.bashrc
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
@@ -96,22 +96,6 @@ if (LLAMA_LLGUIDANCE)
|
||||
include(ExternalProject)
|
||||
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
|
||||
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
|
||||
|
||||
# Set the correct library file extension based on platform
|
||||
if (WIN32)
|
||||
set(LLGUIDANCE_LIB_NAME "llguidance.lib")
|
||||
# Add Windows-specific libraries
|
||||
set(LLGUIDANCE_PLATFORM_LIBS
|
||||
ws2_32 # Windows Sockets API
|
||||
userenv # For GetUserProfileDirectoryW
|
||||
ntdll # For NT functions
|
||||
bcrypt # For BCryptGenRandom
|
||||
)
|
||||
else()
|
||||
set(LLGUIDANCE_LIB_NAME "libllguidance.a")
|
||||
set(LLGUIDANCE_PLATFORM_LIBS "")
|
||||
endif()
|
||||
|
||||
ExternalProject_Add(llguidance_ext
|
||||
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
|
||||
# v0.6.12:
|
||||
@@ -122,18 +106,17 @@ if (LLAMA_LLGUIDANCE)
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND cargo build --release
|
||||
INSTALL_COMMAND ""
|
||||
BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME} ${LLGUIDANCE_PATH}/llguidance.h
|
||||
BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/libllguidance.a ${LLGUIDANCE_PATH}/llguidance.h
|
||||
UPDATE_COMMAND ""
|
||||
)
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_LLGUIDANCE)
|
||||
|
||||
add_library(llguidance STATIC IMPORTED)
|
||||
set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME})
|
||||
set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/libllguidance.a)
|
||||
add_dependencies(llguidance llguidance_ext)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
|
||||
# Add platform libraries to the main target
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance)
|
||||
endif ()
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC .)
|
||||
|
||||
135
common/arg.cpp
135
common/arg.cpp
@@ -365,112 +365,6 @@ static void common_params_print_usage(common_params_context & ctx_arg) {
|
||||
print_options(specific_options);
|
||||
}
|
||||
|
||||
static void common_params_print_completion(common_params_context & ctx_arg) {
|
||||
std::vector<common_arg *> common_options;
|
||||
std::vector<common_arg *> sparam_options;
|
||||
std::vector<common_arg *> specific_options;
|
||||
|
||||
for (auto & opt : ctx_arg.options) {
|
||||
if (opt.is_sparam) {
|
||||
sparam_options.push_back(&opt);
|
||||
} else if (opt.in_example(ctx_arg.ex)) {
|
||||
specific_options.push_back(&opt);
|
||||
} else {
|
||||
common_options.push_back(&opt);
|
||||
}
|
||||
}
|
||||
|
||||
printf("_llama_completions() {\n");
|
||||
printf(" local cur prev opts\n");
|
||||
printf(" COMPREPLY=()\n");
|
||||
printf(" cur=\"${COMP_WORDS[COMP_CWORD]}\"\n");
|
||||
printf(" prev=\"${COMP_WORDS[COMP_CWORD-1]}\"\n\n");
|
||||
|
||||
printf(" opts=\"");
|
||||
auto print_options = [](const std::vector<common_arg *> & options) {
|
||||
for (const common_arg * opt : options) {
|
||||
for (const char * arg : opt->args) {
|
||||
printf("%s ", arg);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
print_options(common_options);
|
||||
print_options(sparam_options);
|
||||
print_options(specific_options);
|
||||
printf("\"\n\n");
|
||||
|
||||
printf(" case \"$prev\" in\n");
|
||||
printf(" --model)\n");
|
||||
printf(" COMPREPLY=( $(compgen -f -X '!*.gguf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
|
||||
printf(" return 0\n");
|
||||
printf(" ;;\n");
|
||||
printf(" --grammar-file)\n");
|
||||
printf(" COMPREPLY=( $(compgen -f -X '!*.gbnf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
|
||||
printf(" return 0\n");
|
||||
printf(" ;;\n");
|
||||
printf(" --chat-template-file)\n");
|
||||
printf(" COMPREPLY=( $(compgen -f -X '!*.jinja' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
|
||||
printf(" return 0\n");
|
||||
printf(" ;;\n");
|
||||
printf(" *)\n");
|
||||
printf(" COMPREPLY=( $(compgen -W \"${opts}\" -- \"$cur\") )\n");
|
||||
printf(" return 0\n");
|
||||
printf(" ;;\n");
|
||||
printf(" esac\n");
|
||||
printf("}\n\n");
|
||||
|
||||
std::set<std::string> executables = {
|
||||
"llama-batched",
|
||||
"llama-batched-bench",
|
||||
"llama-bench",
|
||||
"llama-cli",
|
||||
"llama-convert-llama2c-to-ggml",
|
||||
"llama-cvector-generator",
|
||||
"llama-embedding",
|
||||
"llama-eval-callback",
|
||||
"llama-export-lora",
|
||||
"llama-gbnf-validator",
|
||||
"llama-gen-docs",
|
||||
"llama-gguf",
|
||||
"llama-gguf-hash",
|
||||
"llama-gguf-split",
|
||||
"llama-gritlm",
|
||||
"llama-imatrix",
|
||||
"llama-infill",
|
||||
"llama-llava-cli",
|
||||
"llama-llava-clip-quantize-cli",
|
||||
"llama-lookahead",
|
||||
"llama-lookup",
|
||||
"llama-lookup-create",
|
||||
"llama-lookup-merge",
|
||||
"llama-lookup-stats",
|
||||
"llama-minicpmv-cli",
|
||||
"llama-parallel",
|
||||
"llama-passkey",
|
||||
"llama-perplexity",
|
||||
"llama-q8dot",
|
||||
"llama-quantize",
|
||||
"llama-quantize-stats",
|
||||
"llama-qwen2vl-cli",
|
||||
"llama-retrieval",
|
||||
"llama-run",
|
||||
"llama-save-load-state",
|
||||
"llama-server",
|
||||
"llama-simple",
|
||||
"llama-simple-chat",
|
||||
"llama-speculative",
|
||||
"llama-speculative-simple",
|
||||
"llama-tokenize",
|
||||
"llama-tts",
|
||||
"llama-vdot"
|
||||
};
|
||||
|
||||
for (const auto& exe : executables) {
|
||||
printf("complete -F _llama_completions %s\n", exe.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & value) {
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
auto dev_names = string_split<std::string>(value, ',');
|
||||
@@ -532,10 +426,6 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
|
||||
}
|
||||
exit(0);
|
||||
}
|
||||
if (ctx_arg.params.completion) {
|
||||
common_params_print_completion(ctx_arg);
|
||||
exit(0);
|
||||
}
|
||||
} catch (const std::invalid_argument & ex) {
|
||||
fprintf(stderr, "%s\n", ex.what());
|
||||
ctx_arg.params = params_org;
|
||||
@@ -604,13 +494,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
exit(0);
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--completion-bash"},
|
||||
"print source-able bash completion script for llama.cpp",
|
||||
[](common_params & params) {
|
||||
params.completion = true;
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--verbose-prompt"},
|
||||
string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"),
|
||||
@@ -1063,13 +946,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.sampling.min_p = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--top-nsigma"}, "N",
|
||||
string_format("top-n-sigma sampling (default: %.1f, -1.0 = disabled)", params.sampling.top_n_sigma),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.top_n_sigma = std::stof(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN}).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--xtc-probability"}, "N",
|
||||
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
|
||||
@@ -2099,17 +1975,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-format"}, "FORMAT",
|
||||
"reasoning format (default: deepseek; allowed values: deepseek, none)\n"
|
||||
"controls whether thought tags are extracted from the response, and in which format they're returned. 'none' leaves thoughts unparsed in `message.content`, 'deepseek' puts them in `message.reasoning_content` (for DeepSeek R1 & Command R7B only).\n"
|
||||
"only supported for non-streamed responses",
|
||||
[](common_params & params, const std::string & value) {
|
||||
/**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
|
||||
else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
|
||||
else { std::invalid_argument("invalid value"); }
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
|
||||
add_opt(common_arg(
|
||||
{"--chat-template"}, "JINJA_TEMPLATE",
|
||||
string_format(
|
||||
|
||||
323
common/chat.cpp
323
common/chat.cpp
@@ -12,13 +12,11 @@ std::string common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: return "DeepSeek R1 (extract reasoning)";
|
||||
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
@@ -107,6 +105,7 @@ static common_chat_msg parse_json_tool_calls(
|
||||
std::sregex_iterator rend;
|
||||
std::sregex_iterator rit(it, end, function_regex);
|
||||
if (rit == rend) {
|
||||
fprintf(stderr, "No more tool calls found\n");
|
||||
result.content += std::string(it, end);
|
||||
break;
|
||||
}
|
||||
@@ -116,21 +115,14 @@ static common_chat_msg parse_json_tool_calls(
|
||||
|
||||
json arguments;
|
||||
if (!parse_json(it, end, arguments)) {
|
||||
throw std::runtime_error("Failed to parse json tool call arguments: " + input);
|
||||
throw std::runtime_error("Failed to parse json tool call arguments");
|
||||
}
|
||||
if (!std::regex_search(it, end, match, close_regex)) {
|
||||
throw std::runtime_error("Malformed input, missing closing pattern: " + input);
|
||||
throw std::runtime_error("Malformed input, missing closing pattern");
|
||||
}
|
||||
it = match.suffix().first;
|
||||
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
|
||||
}
|
||||
|
||||
if (!result.tool_calls.empty()) {
|
||||
if (!string_strip(result.content).empty()) {
|
||||
LOG_WRN("Content found with tool calls: %s\n", result.content.c_str());
|
||||
}
|
||||
result.content = "";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -142,11 +134,11 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
|
||||
result.role = "assistant";
|
||||
const auto process_tool_calls = [&](const json & tool_calls) {
|
||||
for (const auto & tool_call : tool_calls) {
|
||||
const auto & arguments = tool_call.at("arguments");
|
||||
const auto & arguments = tool_call["arguments"];
|
||||
result.tool_calls.push_back({
|
||||
tool_call.at("name"),
|
||||
tool_call["name"],
|
||||
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
tool_call.contains("id") ? tool_call.at("id") : "",
|
||||
tool_call.contains("id") ? tool_call["id"] : "",
|
||||
});
|
||||
}
|
||||
};
|
||||
@@ -163,7 +155,7 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
|
||||
|
||||
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
|
||||
for (const auto & tool : tools) {
|
||||
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
|
||||
if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
|
||||
LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
@@ -198,27 +190,27 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
|
||||
|
||||
auto tool_call_schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
const auto & function = tool["function"];
|
||||
auto tool_schema = json {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", function.at("name")},
|
||||
{"const", function["name"]},
|
||||
}},
|
||||
{"arguments", function.at("parameters")},
|
||||
{"arguments", function["parameters"]},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments"})},
|
||||
};
|
||||
if (function.contains("description")) {
|
||||
tool_schema["description"] = function.at("description");
|
||||
tool_schema["description"] = function["description"];
|
||||
}
|
||||
if (inputs.parallel_tool_calls) {
|
||||
tool_schema.at("properties")["id"] = {
|
||||
tool_schema["properties"]["id"] = {
|
||||
{"type", "string"},
|
||||
{"minLength", 4},
|
||||
};
|
||||
tool_schema.at("required").push_back("id");
|
||||
tool_schema["required"].push_back("id");
|
||||
}
|
||||
tool_call_schemas.emplace_back(tool_schema);
|
||||
});
|
||||
@@ -283,21 +275,21 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
if (data.contains("tool_calls")) {
|
||||
for (const auto & tool_call : data.at("tool_calls")) {
|
||||
for (const auto & tool_call : data["tool_calls"]) {
|
||||
result.tool_calls.push_back({
|
||||
tool_call.at("name"),
|
||||
tool_call.at("arguments").dump(),
|
||||
tool_call.contains("id") ? tool_call.at("id") : "",
|
||||
tool_call["name"],
|
||||
tool_call["arguments"].dump(),
|
||||
tool_call.contains("id") ? tool_call["id"] : "",
|
||||
});
|
||||
}
|
||||
} else if (data.contains("tool_call")) {
|
||||
result.tool_calls.push_back({
|
||||
data.at("tool_call").at("name"),
|
||||
data.at("tool_call").at("arguments").dump(),
|
||||
data["tool_call"]["name"],
|
||||
data["tool_call"]["arguments"].dump(),
|
||||
/* id= */ "",
|
||||
});
|
||||
} else if (data.contains("response")) {
|
||||
const auto & response = data.at("response");
|
||||
const auto & response = data["response"];
|
||||
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
|
||||
}
|
||||
return result;
|
||||
@@ -309,7 +301,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
const auto & function = tool["function"];
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
@@ -317,9 +309,9 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
||||
// It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", function.at("name")},
|
||||
{"const", function["name"]},
|
||||
}},
|
||||
{"arguments", function.at("parameters")},
|
||||
{"arguments", function["parameters"]},
|
||||
{"id", {
|
||||
{"type", "string"},
|
||||
// Nemo's template expects a 9-character alphanumeric ID.
|
||||
@@ -354,7 +346,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
const auto & function = tool["function"];
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
@@ -365,9 +357,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
||||
}},
|
||||
{"tool_name", {
|
||||
{"type", "string"},
|
||||
{"const", function.at("name")},
|
||||
{"const", function["name"]},
|
||||
}},
|
||||
{"parameters", function.at("parameters")},
|
||||
{"parameters", function["parameters"]},
|
||||
}},
|
||||
{"required", json::array({"tool_call_id", "tool_name", "parameters"})},
|
||||
});
|
||||
@@ -390,65 +382,39 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
||||
"<|END_THINKING|>",
|
||||
"<|END_ACTION|>",
|
||||
};
|
||||
auto adjusted_messages = json::array();
|
||||
for (const auto & msg : inputs.messages) {
|
||||
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
|
||||
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
|
||||
if (has_reasoning_content && has_tool_calls) {
|
||||
auto adjusted_message = msg;
|
||||
adjusted_message["tool_plan"] = msg.at("reasoning_content");
|
||||
adjusted_message.erase("reasoning_content");
|
||||
adjusted_messages.push_back(adjusted_message);
|
||||
} else {
|
||||
adjusted_messages.push_back(msg);
|
||||
}
|
||||
}
|
||||
data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
|
||||
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING : COMMON_CHAT_FORMAT_COMMAND_R7B;
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
|
||||
static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)");
|
||||
static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
|
||||
static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
|
||||
|
||||
static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
|
||||
static std::regex response_regex("<\\|START_RESPONSE\\|>([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
|
||||
static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
|
||||
std::smatch match;
|
||||
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
|
||||
std::string rest = input;
|
||||
|
||||
if (std::regex_match(rest, match, thought_regex)) {
|
||||
if (extract_reasoning) {
|
||||
result.reasoning_content = match[2].str();
|
||||
} else if (!match[2].str().empty()) {
|
||||
// Let the unparsed thinking tags through in content only if their insides aren't empty.
|
||||
result.content = match[1].str();
|
||||
}
|
||||
rest = match[3].str();
|
||||
}
|
||||
if (std::regex_match(rest, match, action_regex)) {
|
||||
auto actions_str = match[1].str();
|
||||
if (std::regex_match(input, match, response_regex)) {
|
||||
result.content = match[1].str();
|
||||
} else if (std::regex_match(input, match, thought_action_regex)) {
|
||||
result.tool_plan = match[1].str();
|
||||
auto actions_str = match[2].str();
|
||||
auto actions = json::parse(actions_str);
|
||||
for (const auto & action : actions) {
|
||||
result.tool_calls.push_back({
|
||||
/* .name = */ action.at("tool_name"),
|
||||
/* .arguments = */ action.at("parameters").dump(),
|
||||
/* .id = */ action.at("tool_call_id"),
|
||||
/* .name = */ action["tool_name"],
|
||||
/* .arguments = */ action["parameters"].dump(),
|
||||
/* .id = */ action["tool_call_id"],
|
||||
});
|
||||
}
|
||||
} else if (std::regex_match(rest, match, response_regex)) {
|
||||
auto response = match[1].str();
|
||||
result.content += response;
|
||||
} else {
|
||||
result.content += rest;
|
||||
LOG_ERR("Failed to parse command_r output");
|
||||
result.content = input;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
|
||||
if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
|
||||
if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
|
||||
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
|
||||
}
|
||||
const auto & parameters_properties = parameters.at("properties");
|
||||
@@ -502,9 +468,9 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
||||
};
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
builder.resolve_refs(parameters);
|
||||
|
||||
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
||||
@@ -580,90 +546,34 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
|
||||
|
||||
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null();
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
|
||||
"```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
|
||||
});
|
||||
// Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
|
||||
// so we accept common variants (then it's all constrained)
|
||||
builder.add_rule("root",
|
||||
"( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) "
|
||||
"(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
|
||||
"\"<|tool▁calls▁end|>\""
|
||||
" space");
|
||||
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false});
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<|tool▁sep|>",
|
||||
"<|tool▁calls▁end|",
|
||||
"<|tool▁call▁end|>",
|
||||
};
|
||||
}, grammar_options);
|
||||
}
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
|
||||
});
|
||||
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
|
||||
data.preserved_tokens = {
|
||||
"<|tool▁sep|>",
|
||||
"<|tool▁call▁end|>",
|
||||
};
|
||||
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
|
||||
}, grammar_options);
|
||||
auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
|
||||
// Hacks to fix the official (broken) prompt.
|
||||
// It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
|
||||
// until the official template is fixed.
|
||||
if (tmpl.source().find("{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}") != std::string::npos) {
|
||||
// Don't leave the chat dangling after tool results
|
||||
if (string_ends_with(prompt, "<|tool▁outputs▁end|>")) {
|
||||
prompt += "<|end▁of▁sentence|>";
|
||||
if (inputs.add_generation_prompt) {
|
||||
prompt += "<|Assistant|>";
|
||||
}
|
||||
}
|
||||
// Fix up tool call delta example added by Minja
|
||||
prompt = std::regex_replace(
|
||||
prompt,
|
||||
std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"),
|
||||
"$1<|tool▁calls▁end|><|end▁of▁sentence|>$2");
|
||||
}
|
||||
data.prompt = prompt;
|
||||
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
||||
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
|
||||
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) {
|
||||
static std::regex trigger_regex("<|tool▁calls▁begin|>");
|
||||
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
||||
static std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
|
||||
static std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
|
||||
static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
std::smatch match;
|
||||
if (std::regex_match(input, match, reasoning_content_regex)) {
|
||||
std::string rest;
|
||||
if (extract_reasoning) {
|
||||
msg.reasoning_content = string_strip(match[2].str());
|
||||
} else {
|
||||
msg.content = match[1].str();
|
||||
}
|
||||
rest = match[3].str();
|
||||
|
||||
if (std::regex_search(rest, match, tool_calls_regex)) {
|
||||
auto tool_calls = match[1].str();
|
||||
auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
|
||||
msg.tool_calls = std::move(msg2.tool_calls);
|
||||
} else {
|
||||
msg.content += std::string(rest.begin() + rest.find_first_not_of(" \r\n"), rest.end());
|
||||
}
|
||||
} else {
|
||||
msg.content = input;
|
||||
}
|
||||
return msg;
|
||||
static std::regex close_regex("```<|tool▁call▁end|>");
|
||||
return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
@@ -673,20 +583,20 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
|
||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
});
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
const auto & function = tool["function"];
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", function.at("name")},
|
||||
{"const", function["name"]},
|
||||
}},
|
||||
{"arguments", function.at("parameters")},
|
||||
{"arguments", function["parameters"]},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments", "id"})},
|
||||
});
|
||||
@@ -718,15 +628,15 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> first_tool_rules;
|
||||
std::vector<std::string> subsequent_tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
||||
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
||||
@@ -806,9 +716,9 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
const auto & parameters = function.at("parameters");
|
||||
std::string name = function.at("name");
|
||||
const auto & function = tool["function"];
|
||||
const auto & parameters = function["parameters"];
|
||||
std::string name = function["name"];
|
||||
if (name == "python" || name == "ipython") {
|
||||
if (!parameters.contains("type")) {
|
||||
throw std::runtime_error("Missing type in python tool");
|
||||
@@ -879,9 +789,9 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
builder.resolve_refs(parameters);
|
||||
tool_rules.push_back(builder.add_schema(name + "-call", {
|
||||
{"type", "object"},
|
||||
@@ -929,9 +839,9 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
|
||||
if (!parse_json(it, end, call)) {
|
||||
throw std::runtime_error("Failed to parse json tool call");
|
||||
}
|
||||
const auto & arguments = call.at("arguments");
|
||||
const auto & arguments = call["arguments"];
|
||||
result.tool_calls.push_back({
|
||||
call.at("name"),
|
||||
call["name"],
|
||||
arguments.dump(),
|
||||
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
/* id= */ "",
|
||||
@@ -974,72 +884,47 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
|
||||
}
|
||||
|
||||
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
auto has_tools = !inputs.tools.is_null() && inputs.tool_choice != "none";
|
||||
LOG_DBG("[%s] has_tools=%s\n", __func__, has_tools ? "true" : "false");
|
||||
|
||||
if (has_tools && !inputs.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
}
|
||||
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
|
||||
if (inputs.tools.is_array()) {
|
||||
if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
}
|
||||
if (caps.supports_tool_calls && !caps.supports_tools) {
|
||||
LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
|
||||
}
|
||||
}
|
||||
|
||||
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
|
||||
if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) {
|
||||
return common_chat_params_init_deepseek_r1(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Command R7B: : use handler in all cases except json schema (thinking / tools).
|
||||
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
|
||||
return common_chat_params_init_command_r7b(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Use generic handler when mixing tools + JSON schema.
|
||||
// TODO: support that mix in handlers below.
|
||||
if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
|
||||
return common_chat_params_init_generic(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
|
||||
if (src.find(">>>all") != std::string::npos) {
|
||||
// Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
|
||||
return common_chat_params_init_functionary_v3_2(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
|
||||
if (src.find(" functools[") != std::string::npos) {
|
||||
// Firefunction v2 requires datetime and functions in the context, even w/o tools.
|
||||
return common_chat_params_init_firefunction_v2(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (inputs.tools.is_null() || inputs.tool_choice == "none") {
|
||||
if (!has_tools) {
|
||||
return common_chat_params_init_without_tools(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
||||
if (src.find("<tool_call>") != std::string::npos) {
|
||||
return common_chat_params_init_hermes_2_pro(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Functionary v3.1 (w/ tools)
|
||||
if (src.find("<|start_header_id|>") != std::string::npos
|
||||
&& src.find("<function=") != std::string::npos) {
|
||||
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Llama 3.1, 3.2, 3.3 (w/ tools)
|
||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
|
||||
}
|
||||
|
||||
// Mistral Nemo (w/ tools)
|
||||
if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
|
||||
return common_chat_params_init_deepseek_r1(tmpl, inputs);
|
||||
}
|
||||
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
||||
return common_chat_params_init_mistral_nemo(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Generic fallback
|
||||
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
|
||||
return common_chat_params_init_command_r7b(tmpl, inputs);
|
||||
}
|
||||
return common_chat_params_init_generic(tmpl, inputs);
|
||||
}
|
||||
|
||||
@@ -1064,9 +949,7 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
|
||||
return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
|
||||
return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ false);
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING:
|
||||
return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ true);
|
||||
return common_chat_parse_deepseek_r1(input);
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
|
||||
return common_chat_parse_functionary_v3_2(input);
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
|
||||
@@ -1076,9 +959,7 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
|
||||
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
|
||||
return common_chat_parse_firefunction_v2(input);
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B:
|
||||
return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false);
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING:
|
||||
return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true);
|
||||
return common_chat_parse_command_r7b(input);
|
||||
default:
|
||||
throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ struct common_chat_inputs {
|
||||
bool stream;
|
||||
std::string grammar;
|
||||
bool add_generation_prompt = true;
|
||||
bool extract_reasoning = true;
|
||||
};
|
||||
|
||||
enum common_chat_format {
|
||||
@@ -29,13 +28,11 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
|
||||
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
@@ -140,7 +140,6 @@ struct common_params_sampling {
|
||||
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
||||
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float top_n_sigma = -1.00f;// -1.0 = disabled
|
||||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
bool ignore_eos = false;
|
||||
@@ -203,11 +202,6 @@ struct common_params_vocoder {
|
||||
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
||||
};
|
||||
|
||||
enum common_reasoning_format {
|
||||
COMMON_REASONING_FORMAT_NONE,
|
||||
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
|
||||
};
|
||||
|
||||
struct common_params {
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
int32_t n_ctx = 4096; // context size
|
||||
@@ -298,7 +292,6 @@ struct common_params {
|
||||
bool kl_divergence = false; // compute KL divergence
|
||||
|
||||
bool usage = false; // print usage
|
||||
bool completion = false; // print source-able completion script
|
||||
bool use_color = false; // use color to distinguish generations and inputs
|
||||
bool special = false; // enable special token output
|
||||
bool interactive = false; // interactive mode
|
||||
@@ -353,7 +346,6 @@ struct common_params {
|
||||
std::string chat_template = ""; // NOLINT
|
||||
bool use_jinja = false; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
|
||||
std::vector<std::string> api_keys;
|
||||
|
||||
@@ -631,7 +623,7 @@ struct common_chat_msg {
|
||||
std::string role;
|
||||
std::string content;
|
||||
std::vector<common_tool_call> tool_calls;
|
||||
std::string reasoning_content = "";
|
||||
std::string tool_plan = "";
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#include "log.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <condition_variable>
|
||||
#include <cstdarg>
|
||||
#include <cstdio>
|
||||
|
||||
@@ -134,11 +134,11 @@ std::string common_params_sampling::print() const {
|
||||
snprintf(result, sizeof(result),
|
||||
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
||||
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
|
||||
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
|
||||
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
|
||||
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
||||
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
||||
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
|
||||
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
|
||||
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
|
||||
mirostat, mirostat_eta, mirostat_tau);
|
||||
|
||||
return std::string(result);
|
||||
@@ -151,6 +151,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
|
||||
lparams.no_perf = params.no_perf;
|
||||
|
||||
std::vector<const char *> trigger_words;
|
||||
trigger_words.reserve(params.grammar_trigger_words.size());
|
||||
for (const auto & str : params.grammar_trigger_words) {
|
||||
trigger_words.push_back(str.word.c_str());
|
||||
}
|
||||
|
||||
struct llama_sampler * grmr;
|
||||
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
@@ -159,12 +165,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
} else {
|
||||
std::vector<const char *> trigger_words;
|
||||
trigger_words.reserve(params.grammar_trigger_words.size());
|
||||
for (const auto & str : params.grammar_trigger_words) {
|
||||
trigger_words.push_back(str.word.c_str());
|
||||
}
|
||||
|
||||
grmr = params.grammar_lazy
|
||||
? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
|
||||
trigger_words.data(), trigger_words.size(),
|
||||
@@ -188,51 +188,45 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
params.logit_bias.data()));
|
||||
|
||||
if (params.mirostat == 0) {
|
||||
if (params.top_n_sigma >= 0) {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
|
||||
} else {
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
{
|
||||
std::vector<const char *> c_breakers;
|
||||
c_breakers.reserve(params.dry_sequence_breakers.size());
|
||||
for (const auto & str : params.dry_sequence_breakers) {
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
{
|
||||
std::vector<const char *> c_breakers;
|
||||
c_breakers.reserve(params.dry_sequence_breakers.size());
|
||||
for (const auto & str : params.dry_sequence_breakers) {
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
}
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||
|
||||
@@ -46,7 +46,7 @@ cmake --build build --config Release
|
||||
```
|
||||
|
||||
- Building for Windows (x86, x64 and arm64) with MSVC or clang as compilers:
|
||||
- Install Visual Studio 2022, e.g. via the [Community Edition](https://visualstudio.microsoft.com/vs/community/). In the installer, select at least the following options (this also automatically installs the required additional tools like CMake,...):
|
||||
- Install Visual Studio 2022, e.g. via the [Community Edition](https://visualstudio.microsoft.com/de/vs/community/). In the installer, select at least the following options (this also automatically installs the required additional tools like CMake,...):
|
||||
- Tab Workload: Desktop-development with C++
|
||||
- Tab Components (select quickly via search): C++-_CMake_ Tools for Windows, _Git_ for Windows, C++-_Clang_ Compiler for Windows, MS-Build Support for LLVM-Toolset (clang)
|
||||
- Please remember to always use a Developer Command Prompt / PowerShell for VS2022 for git, build, test
|
||||
|
||||
@@ -69,7 +69,7 @@ You may want to pass in some different `ARGS`, depending on the CUDA environment
|
||||
|
||||
The defaults are:
|
||||
|
||||
- `CUDA_VERSION` set to `12.4.0`
|
||||
- `CUDA_VERSION` set to `12.6.0`
|
||||
- `CUDA_DOCKER_ARCH` set to the cmake build default, which includes all the supported architectures
|
||||
|
||||
The resulting images, are essentially the same as the non-CUDA images:
|
||||
@@ -104,7 +104,7 @@ You may want to pass in some different `ARGS`, depending on the MUSA environment
|
||||
|
||||
The defaults are:
|
||||
|
||||
- `MUSA_VERSION` set to `rc3.1.1`
|
||||
- `MUSA_VERSION` set to `rc3.1.0`
|
||||
|
||||
The resulting images, are essentially the same as the non-MUSA images:
|
||||
|
||||
|
||||
@@ -13,15 +13,13 @@ cmake -B build -DLLAMA_LLGUIDANCE=ON
|
||||
make -C build -j
|
||||
```
|
||||
|
||||
For Windows use `cmake --build build --config Release` instead of `make`.
|
||||
|
||||
This requires the Rust compiler and the `cargo` tool to be [installed](https://www.rust-lang.org/tools/install).
|
||||
|
||||
## Interface
|
||||
|
||||
There are no new command-line arguments or modifications to `common_params`. When enabled, grammars starting with `%llguidance` are passed to LLGuidance instead of the [current](../grammars/README.md) llama.cpp grammars. Additionally, JSON Schema requests (e.g., using the `-j` argument in `llama-cli`) are also passed to LLGuidance.
|
||||
|
||||
For your existing GBNF grammars, you can use [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/python/llguidance/gbnf_to_lark.py) to convert them to LLGuidance Lark-like format.
|
||||
For your existing GBNF grammars, you can use [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/scripts/gbnf_to_lark.py) to convert them to LLGuidance Lark-like format.
|
||||
|
||||
## Performance
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
|
||||
@@ -876,8 +876,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
struct test {
|
||||
static const std::string build_commit;
|
||||
static const int build_number;
|
||||
const std::string cpu_info;
|
||||
const std::string gpu_info;
|
||||
static const std::string cpu_info;
|
||||
static const std::string gpu_info;
|
||||
std::string model_filename;
|
||||
std::string model_type;
|
||||
uint64_t model_size;
|
||||
@@ -903,10 +903,7 @@ struct test {
|
||||
std::string test_time;
|
||||
std::vector<uint64_t> samples_ns;
|
||||
|
||||
test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) :
|
||||
cpu_info(get_cpu_info()),
|
||||
gpu_info(get_gpu_info()) {
|
||||
|
||||
test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) {
|
||||
model_filename = inst.model;
|
||||
char buf[128];
|
||||
llama_model_desc(lmodel, buf, sizeof(buf));
|
||||
@@ -1061,6 +1058,8 @@ struct test {
|
||||
|
||||
const std::string test::build_commit = LLAMA_COMMIT;
|
||||
const int test::build_number = LLAMA_BUILD_NUMBER;
|
||||
const std::string test::cpu_info = get_cpu_info();
|
||||
const std::string test::gpu_info = get_gpu_info();
|
||||
|
||||
struct printer {
|
||||
virtual ~printer() {}
|
||||
|
||||
@@ -265,14 +265,6 @@ Being experimental and unique, XTC is disabled by default. The recommended combi
|
||||
|
||||
Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1`
|
||||
|
||||
### Top-nσ Sampling
|
||||
|
||||
- `--top-nsigma N`: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1, -1 = disabled).
|
||||
|
||||
Top-nσ sampling is a text generation method that selects tokens based on a statistical threshold in pre-softmax logits. It works by only sampling from tokens with logits that are within n * σ of the maximum logit. This method helps maintain a stable sampling space regardless of temperature scaling, allowing it to perform well on reasoning tasks even in high temperatures. Without complex probability manipulation, it efficiently filters tokens directly on the pre-softmax logits. A higher value for top-nsigma (e.g., 5) will take more noisy tokens into consideration, while a lower value (e.g., 1) will focous on the more informative region of the sampling space.
|
||||
|
||||
Example usage: `--top-nsigma 1`
|
||||
|
||||
### Logit Bias
|
||||
|
||||
- `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion.
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
|
||||
@@ -127,7 +127,6 @@ The project is under active development, and we are [looking for feedback and co
|
||||
| `--grammar-file FNAME` | file to read grammar from |
|
||||
| `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead |
|
||||
| `--jinja` | Enable experimental Jinja templating engine (required for tool use) |
|
||||
| `--reasoning-format FORMAT` | Controls extraction of model thinking traces and the format / field in which they are returned (default: `deepseek`; allowed values: `deepseek`, `none`; requires `--jinja`). `none` will leave thinking traces inline in `message.content` in a model-specific format, while `deepseek` will return them separately under `message.reasoning_content` |
|
||||
|
||||
**Example-specific params**
|
||||
|
||||
@@ -1137,252 +1136,61 @@ curl http://localhost:8080/v1/chat/completions \
|
||||
|
||||
| Template | Format |
|
||||
|----------|--------|
|
||||
| Almawave-Velvet-14B.jinja | Hermes 2 Pro |
|
||||
| AtlaAI-Selene-1-Mini-Llama-3.1-8B.jinja | Llama 3.x |
|
||||
| CohereForAI-aya-expanse-8b.jinja | Generic |
|
||||
| CohereForAI-c4ai-command-r-plus-default.jinja | Generic |
|
||||
| CohereForAI-c4ai-command-r-plus-rag.jinja | Generic |
|
||||
| CohereForAI-c4ai-command-r-plus-tool_use.jinja | Generic |
|
||||
| CohereForAI-c4ai-command-r7b-12-2024-default.jinja | Command R7B (extract reasoning) |
|
||||
| CohereForAI-c4ai-command-r7b-12-2024-rag.jinja | Command R7B (extract reasoning) |
|
||||
| CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja | Command R7B (extract reasoning) |
|
||||
| CohereForAI-c4ai-command-r7b-12-2024.jinja | Generic |
|
||||
| DavieLion-Llama-3.2-1B-SPIN-iter3.jinja | Generic |
|
||||
| Delta-Vector-Rei-12B.jinja | Mistral Nemo |
|
||||
| EpistemeAI-Mistral-Nemo-Instruct-12B-Philosophy-Math.jinja | Mistral Nemo |
|
||||
| FlofloB-83k_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit.jinja | Hermes 2 Pro |
|
||||
| FlofloB-test_continued_pretraining_Phi-3-mini-4k-instruct_Unsloth_merged_16bit.jinja | Generic |
|
||||
| HelpingAI-HAI-SER.jinja | Generic |
|
||||
| HuggingFaceTB-SmolLM2-1.7B-Instruct.jinja | Generic |
|
||||
| HuggingFaceTB-SmolLM2-135M-Instruct.jinja | Generic |
|
||||
| HuggingFaceTB-SmolLM2-360M-Instruct.jinja | Generic |
|
||||
| INSAIT-Institute-BgGPT-Gemma-2-27B-IT-v1.0.jinja | Generic |
|
||||
| Ihor-Text2Graph-R1-Qwen2.5-0.5b.jinja | Hermes 2 Pro |
|
||||
| Infinigence-Megrez-3B-Instruct.jinja | Generic |
|
||||
| Josephgflowers-TinyLlama_v1.1_math_code-world-test-1.jinja | Generic |
|
||||
| LGAI-EXAONE-EXAONE-3.5-2.4B-Instruct.jinja | Generic |
|
||||
| LGAI-EXAONE-EXAONE-3.5-7.8B-Instruct.jinja | Generic |
|
||||
| LatitudeGames-Wayfarer-12B.jinja | Generic |
|
||||
| Magpie-Align-Llama-3-8B-Magpie-Align-v0.1.jinja | Generic |
|
||||
| Magpie-Align-Llama-3.1-8B-Magpie-Align-v0.1.jinja | Generic |
|
||||
| MaziyarPanahi-calme-3.2-instruct-78b.jinja | Generic |
|
||||
| MiniMaxAI-MiniMax-Text-01.jinja | Generic |
|
||||
| MiniMaxAI-MiniMax-VL-01.jinja | Generic |
|
||||
| NaniDAO-deepseek-r1-qwen-2.5-32B-ablated.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| NexaAIDev-Octopus-v2.jinja | Generic |
|
||||
| NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | Generic |
|
||||
| NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | Hermes 2 Pro |
|
||||
| NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | Generic |
|
||||
| NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | Hermes 2 Pro |
|
||||
| NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | Generic |
|
||||
| NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | Hermes 2 Pro |
|
||||
| NovaSky-AI-Sky-T1-32B-Flash.jinja | Hermes 2 Pro |
|
||||
| NovaSky-AI-Sky-T1-32B-Preview.jinja | Hermes 2 Pro |
|
||||
| OnlyCheeini-greesychat-turbo.jinja | Generic |
|
||||
| Orenguteng-Llama-3.1-8B-Lexi-Uncensored-V2.jinja | Llama 3.x |
|
||||
| OrionStarAI-Orion-14B-Chat.jinja | Generic |
|
||||
| PowerInfer-SmallThinker-3B-Preview.jinja | Generic |
|
||||
| PrimeIntellect-INTELLECT-1-Instruct.jinja | Generic |
|
||||
| Qwen-QVQ-72B-Preview.jinja | Generic |
|
||||
| Qwen-QwQ-32B-Preview.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen1.5-7B-Chat.jinja | Generic |
|
||||
| Qwen-Qwen2-7B-Instruct.jinja | Generic |
|
||||
| Qwen-Qwen2-VL-72B-Instruct.jinja | Generic |
|
||||
| Qwen-Qwen2-VL-7B-Instruct.jinja | Generic |
|
||||
| Qwen-Qwen2.5-0.5B.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-1.5B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-14B-Instruct-1M.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-14B.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-32B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-32B.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-3B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-72B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-7B-Instruct-1M.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-7B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-7B.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-Coder-32B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-Coder-7B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-Math-1.5B.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-Math-7B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-VL-3B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-VL-72B-Instruct.jinja | Hermes 2 Pro |
|
||||
| Qwen-Qwen2.5-VL-7B-Instruct.jinja | Hermes 2 Pro |
|
||||
| RWKV-Red-Team-ARWKV-7B-Preview-0.1.jinja | Hermes 2 Pro |
|
||||
| SakanaAI-TinySwallow-1.5B-Instruct.jinja | Hermes 2 Pro |
|
||||
| SakanaAI-TinySwallow-1.5B.jinja | Hermes 2 Pro |
|
||||
| Sao10K-70B-L3.3-Cirrus-x1.jinja | Llama 3.x |
|
||||
| SentientAGI-Dobby-Mini-Leashed-Llama-3.1-8B.jinja | Llama 3.x |
|
||||
| SentientAGI-Dobby-Mini-Unhinged-Llama-3.1-8B.jinja | Llama 3.x |
|
||||
| Steelskull-L3.3-Damascus-R1.jinja | Llama 3.x |
|
||||
| Steelskull-L3.3-MS-Nevoria-70b.jinja | Llama 3.x |
|
||||
| Steelskull-L3.3-Nevoria-R1-70b.jinja | Llama 3.x |
|
||||
| THUDM-glm-4-9b-chat.jinja | Generic |
|
||||
| THUDM-glm-edge-1.5b-chat.jinja | Generic |
|
||||
| Tarek07-Progenitor-V1.1-LLaMa-70B.jinja | Llama 3.x |
|
||||
| TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | Generic |
|
||||
| TinyLlama-TinyLlama-1.1B-Chat-v1.0.jinja | Generic |
|
||||
| UCLA-AGI-Mistral7B-PairRM-SPPO-Iter3.jinja | Generic |
|
||||
| ValiantLabs-Llama3.1-8B-Enigma.jinja | Llama 3.x |
|
||||
| abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | Generic |
|
||||
| ai21labs-AI21-Jamba-1.5-Large.jinja | Generic |
|
||||
| allenai-Llama-3.1-Tulu-3-405B-SFT.jinja | Generic |
|
||||
| allenai-Llama-3.1-Tulu-3-405B.jinja | Generic |
|
||||
| allenai-Llama-3.1-Tulu-3-8B.jinja | Generic |
|
||||
| arcee-ai-Virtuoso-Lite.jinja | Hermes 2 Pro |
|
||||
| arcee-ai-Virtuoso-Medium-v2.jinja | Hermes 2 Pro |
|
||||
| arcee-ai-Virtuoso-Small-v2.jinja | Hermes 2 Pro |
|
||||
| avemio-GRAG-NEMO-12B-ORPO-HESSIAN-AI.jinja | Generic |
|
||||
| bespokelabs-Bespoke-Stratos-7B.jinja | Hermes 2 Pro |
|
||||
| bfuzzy1-acheron-m1a-llama.jinja | Generic |
|
||||
| bofenghuang-vigogne-2-70b-chat.jinja | Generic |
|
||||
| bytedance-research-UI-TARS-72B-DPO.jinja | Generic |
|
||||
| bytedance-research-UI-TARS-7B-DPO.jinja | Generic |
|
||||
| bytedance-research-UI-TARS-7B-SFT.jinja | Generic |
|
||||
| carsenk-phi3.5_mini_exp_825_uncensored.jinja | Generic |
|
||||
| cyberagent-DeepSeek-R1-Distill-Qwen-14B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| cyberagent-DeepSeek-R1-Distill-Qwen-32B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| databricks-dbrx-instruct.jinja | Generic |
|
||||
| deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | Generic |
|
||||
| deepseek-ai-DeepSeek-Coder-V2-Lite-Base.jinja | Generic |
|
||||
| deepseek-ai-DeepSeek-Coder-V2-Lite-Instruct.jinja | Generic |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Llama-70B.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Qwen-1.5B.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Qwen-14B.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-R1-Zero.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-V2-Lite.jinja | Generic |
|
||||
| deepseek-ai-DeepSeek-V2.5.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-DeepSeek-V3.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| deepseek-ai-deepseek-coder-33b-instruct.jinja | Generic |
|
||||
| deepseek-ai-deepseek-coder-6.7b-instruct.jinja | Generic |
|
||||
| deepseek-ai-deepseek-coder-7b-instruct-v1.5.jinja | Generic |
|
||||
| deepseek-ai-deepseek-llm-67b-chat.jinja | Generic |
|
||||
| deepseek-ai-deepseek-llm-7b-chat.jinja | Generic |
|
||||
| dicta-il-dictalm2.0-instruct.jinja | Generic |
|
||||
| ehristoforu-Falcon3-8B-Franken-Basestruct.jinja | Hermes 2 Pro |
|
||||
| fireworks-ai-llama-3-firefunction-v2.jinja | FireFunction v2 |
|
||||
| godlikehhd-alpaca_data_sampled_ifd_new_5200.jinja | Hermes 2 Pro |
|
||||
| godlikehhd-alpaca_data_score_max_0.7_2600.jinja | Hermes 2 Pro |
|
||||
| google-gemma-2-27b-it.jinja | Generic |
|
||||
| google-gemma-2-2b-it.jinja | Generic |
|
||||
| google-gemma-2-2b-jpn-it.jinja | Generic |
|
||||
| google-gemma-7b-it.jinja | Generic |
|
||||
| huihui-ai-DeepSeek-R1-Distill-Llama-70B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| huihui-ai-DeepSeek-R1-Distill-Llama-8B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| huihui-ai-DeepSeek-R1-Distill-Qwen-14B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| huihui-ai-DeepSeek-R1-Distill-Qwen-32B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| huihui-ai-DeepSeek-R1-Distill-Qwen-7B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| huihui-ai-Qwen2.5-14B-Instruct-1M-abliterated.jinja | Hermes 2 Pro |
|
||||
| ibm-granite-granite-3.1-8b-instruct.jinja | Generic |
|
||||
| indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | Generic |
|
||||
| inflatebot-MN-12B-Mag-Mell-R1.jinja | Generic |
|
||||
| jinaai-ReaderLM-v2.jinja | Generic |
|
||||
| kms7530-chemeng_qwen-math-7b_24_1_100_1_nonmath.jinja | Hermes 2 Pro |
|
||||
| knifeayumu-Cydonia-v1.3-Magnum-v4-22B.jinja | Mistral Nemo |
|
||||
| langgptai-qwen1.5-7b-chat-sa-v0.1.jinja | Generic |
|
||||
| lightblue-DeepSeek-R1-Distill-Qwen-7B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| mattshumer-Reflection-Llama-3.1-70B.jinja | Generic |
|
||||
| meetkai-functionary-medium-v3.1.jinja | Functionary v3.1 Llama 3.1 |
|
||||
| meetkai-functionary-medium-v3.2.jinja | Functionary v3.2 |
|
||||
| meta-llama-Llama-2-7b-chat-hf.jinja | Generic |
|
||||
| meta-llama-Llama-3.1-8B-Instruct.jinja | Llama 3.x |
|
||||
| meta-llama-Llama-3.2-11B-Vision-Instruct.jinja | Llama 3.x |
|
||||
| meta-llama-Llama-3.2-1B-Instruct.jinja | Llama 3.x |
|
||||
| meta-llama-Llama-3.2-3B-Instruct.jinja | Llama 3.x |
|
||||
| meta-llama-Llama-3.3-70B-Instruct.jinja | Llama 3.x |
|
||||
| meta-llama-Meta-Llama-3-8B-Instruct.jinja | Generic |
|
||||
| meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | Llama 3.x |
|
||||
| microsoft-Phi-3-medium-4k-instruct.jinja | Generic |
|
||||
| microsoft-Phi-3-mini-4k-instruct.jinja | Generic |
|
||||
| microsoft-Phi-3-small-8k-instruct.jinja | Generic |
|
||||
| microsoft-Phi-3.5-mini-instruct.jinja | Generic |
|
||||
| microsoft-Phi-3.5-vision-instruct.jinja | Generic |
|
||||
| microsoft-phi-4.jinja | Generic |
|
||||
| migtissera-Tess-3-Mistral-Nemo-12B.jinja | Generic |
|
||||
| ministral-Ministral-3b-instruct.jinja | Generic |
|
||||
| mistralai-Codestral-22B-v0.1.jinja | Generic |
|
||||
| mistralai-Mistral-7B-Instruct-v0.1.jinja | Generic |
|
||||
| mistralai-Mistral-7B-Instruct-v0.2.jinja | Generic |
|
||||
| mistralai-Mistral-7B-Instruct-v0.3.jinja | Mistral Nemo |
|
||||
| mistralai-Mistral-Large-Instruct-2407.jinja | Mistral Nemo |
|
||||
| mistralai-Mistral-Large-Instruct-2411.jinja | Generic |
|
||||
| mistralai-Mistral-Nemo-Instruct-2407.jinja | Mistral Nemo |
|
||||
| mistralai-Mistral-Small-24B-Instruct-2501.jinja | Generic |
|
||||
| mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | Generic |
|
||||
| mkurman-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
|
||||
| mlabonne-AlphaMonarch-7B.jinja | Generic |
|
||||
| mlx-community-Josiefied-Qwen2.5-0.5B-Instruct-abliterated-v1-float32.jinja | Hermes 2 Pro |
|
||||
| mlx-community-Qwen2.5-VL-7B-Instruct-8bit.jinja | Hermes 2 Pro |
|
||||
| mobiuslabsgmbh-DeepSeek-R1-ReDistill-Qwen-1.5B-v1.1.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| netcat420-MFANNv0.20.jinja | Generic |
|
||||
| netcat420-MFANNv0.24.jinja | Generic |
|
||||
| netease-youdao-Confucius-o1-14B.jinja | Hermes 2 Pro |
|
||||
| nvidia-AceMath-7B-RM.jinja | Hermes 2 Pro |
|
||||
| nvidia-Eagle2-1B.jinja | Hermes 2 Pro |
|
||||
| nvidia-Eagle2-9B.jinja | Hermes 2 Pro |
|
||||
| nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | Llama 3.x |
|
||||
| onnx-community-DeepSeek-R1-Distill-Qwen-1.5B-ONNX.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| open-thoughts-OpenThinker-7B.jinja | Hermes 2 Pro |
|
||||
| openchat-openchat-3.5-0106.jinja | Generic |
|
||||
| pankajmathur-orca_mini_v6_8b.jinja | Generic |
|
||||
| princeton-nlp-Mistral-7B-Base-SFT-RDPO.jinja | Generic |
|
||||
| princeton-nlp-Mistral-7B-Instruct-DPO.jinja | Generic |
|
||||
| princeton-nlp-Mistral-7B-Instruct-RDPO.jinja | Generic |
|
||||
| prithivMLmods-Bellatrix-Tiny-1.5B-R1.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Bellatrix-Tiny-1B-R1.jinja | Llama 3.x |
|
||||
| prithivMLmods-Bellatrix-Tiny-1B-v3.jinja | Generic |
|
||||
| prithivMLmods-Bellatrix-Tiny-3B-R1.jinja | Llama 3.x |
|
||||
| prithivMLmods-Blaze-14B-xElite.jinja | Generic |
|
||||
| prithivMLmods-Calcium-Opus-14B-Elite2-R1.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Calme-Ties-78B.jinja | Generic |
|
||||
| prithivMLmods-Calme-Ties2-78B.jinja | Generic |
|
||||
| prithivMLmods-Calme-Ties3-78B.jinja | Generic |
|
||||
| prithivMLmods-ChemQwen2-vL.jinja | Generic |
|
||||
| prithivMLmods-GWQ2b.jinja | Generic |
|
||||
| prithivMLmods-LatexMind-2B-Codec.jinja | Generic |
|
||||
| prithivMLmods-Llama-3.2-6B-AlgoCode.jinja | Llama 3.x |
|
||||
| prithivMLmods-Megatron-Opus-14B-Exp.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Megatron-Opus-14B-Stock.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Megatron-Opus-7B-Exp.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Omni-Reasoner-Merged.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Omni-Reasoner4-Merged.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Primal-Opus-14B-Optimus-v1.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-QwQ-Math-IO-500M.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Qwen-7B-Distill-Reasoner.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| prithivMLmods-Qwen2.5-1.5B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Qwen2.5-32B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Qwen2.5-7B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
|
||||
| prithivMLmods-Triangulum-v2-10B.jinja | Hermes 2 Pro |
|
||||
| qingy2024-Falcon3-2x10B-MoE-Instruct.jinja | Hermes 2 Pro |
|
||||
| rubenroy-Zurich-14B-GCv2-5m.jinja | Hermes 2 Pro |
|
||||
| rubenroy-Zurich-7B-GCv2-5m.jinja | Hermes 2 Pro |
|
||||
| silma-ai-SILMA-Kashif-2B-Instruct-v1.0.jinja | Generic |
|
||||
| simplescaling-s1-32B.jinja | Hermes 2 Pro |
|
||||
| sometimesanotion-Lamarck-14B-v0.7.jinja | Hermes 2 Pro |
|
||||
| sonthenguyen-zephyr-sft-bnb-4bit-DPO-mtbr-180steps.jinja | Generic |
|
||||
| sthenno-tempesthenno-icy-0130.jinja | Generic |
|
||||
| sumink-qwft.jinja | Hermes 2 Pro |
|
||||
| teknium-OpenHermes-2.5-Mistral-7B.jinja | Generic |
|
||||
| thirdeyeai-elevate360m.jinja | Generic |
|
||||
| tiiuae-Falcon3-10B-Instruct.jinja | Hermes 2 Pro |
|
||||
| unsloth-DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| unsloth-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| unsloth-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) |
|
||||
| unsloth-Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit.jinja | Generic |
|
||||
| upstage-solar-pro-preview-instruct.jinja | Generic |
|
||||
| whyhow-ai-PatientSeek.jinja | Generic |
|
||||
| xwen-team-Xwen-72B-Chat.jinja | Hermes 2 Pro |
|
||||
| xwen-team-Xwen-7B-Chat.jinja | Hermes 2 Pro |
|
||||
| CohereForAI-c4ai-command-r-plus-default.jinja | generic tool calls |
|
||||
| CohereForAI-c4ai-command-r-plus-rag.jinja | generic tool calls |
|
||||
| CohereForAI-c4ai-command-r-plus-tool_use.jinja | generic tool calls |
|
||||
| MiniMaxAI-MiniMax-Text-01.jinja | generic tool calls |
|
||||
| NexaAIDev-Octopus-v2.jinja | generic tool calls |
|
||||
| NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | generic tool calls |
|
||||
| NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | hermes 2 pro tool calls |
|
||||
| NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | generic tool calls |
|
||||
| NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | hermes 2 pro tool calls |
|
||||
| NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | generic tool calls |
|
||||
| NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | hermes 2 pro tool calls |
|
||||
| OrionStarAI-Orion-14B-Chat.jinja | generic tool calls |
|
||||
| Qwen-QwQ-32B-Preview.jinja | hermes 2 pro tool calls |
|
||||
| Qwen-Qwen2-7B-Instruct.jinja | generic tool calls |
|
||||
| Qwen-Qwen2-VL-7B-Instruct.jinja | generic tool calls |
|
||||
| Qwen-Qwen2.5-7B-Instruct.jinja | hermes 2 pro tool calls |
|
||||
| Qwen-Qwen2.5-Math-7B-Instruct.jinja | hermes 2 pro tool calls |
|
||||
| TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | generic tool calls |
|
||||
| abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | generic tool calls |
|
||||
| bofenghuang-vigogne-2-70b-chat.jinja | generic tool calls |
|
||||
| databricks-dbrx-instruct.jinja | generic tool calls |
|
||||
| deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | generic tool calls |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | deepseek r1 tool calls |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | deepseek r1 tool calls |
|
||||
| deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | deepseek r1 tool calls |
|
||||
| deepseek-ai-DeepSeek-V2.5.jinja | deepseek r1 tool calls |
|
||||
| deepseek-ai-deepseek-coder-33b-instruct.jinja | generic tool calls |
|
||||
| google-gemma-2-2b-it.jinja | generic tool calls |
|
||||
| google-gemma-7b-it.jinja | generic tool calls |
|
||||
| indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | generic tool calls |
|
||||
| mattshumer-Reflection-Llama-3.1-70B.jinja | generic tool calls |
|
||||
| meetkai-functionary-medium-v3.2.jinja | functionary v3.2 tool calls |
|
||||
| meta-llama-Llama-3.1-8B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) |
|
||||
| meta-llama-Llama-3.2-3B-Instruct.jinja | llama 3.x tool calls |
|
||||
| meta-llama-Llama-3.3-70B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) |
|
||||
| meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) |
|
||||
| microsoft-Phi-3-medium-4k-instruct.jinja | generic tool calls |
|
||||
| microsoft-Phi-3-mini-4k-instruct.jinja | generic tool calls |
|
||||
| microsoft-Phi-3-small-8k-instruct.jinja | generic tool calls |
|
||||
| microsoft-Phi-3.5-mini-instruct.jinja | generic tool calls |
|
||||
| microsoft-Phi-3.5-vision-instruct.jinja | generic tool calls |
|
||||
| mistralai-Mistral-7B-Instruct-v0.2.jinja | generic tool calls |
|
||||
| mistralai-Mistral-Large-Instruct-2407.jinja | mistral nemo tool calls |
|
||||
| mistralai-Mistral-Large-Instruct-2411.jinja | generic tool calls |
|
||||
| mistralai-Mistral-Nemo-Instruct-2407.jinja | mistral nemo tool calls |
|
||||
| mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | generic tool calls |
|
||||
| mlabonne-AlphaMonarch-7B.jinja | generic tool calls |
|
||||
| nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | llama 3.x tool calls (w/ builtin tools) |
|
||||
| openchat-openchat-3.5-0106.jinja | generic tool calls |
|
||||
| teknium-OpenHermes-2.5-Mistral-7B.jinja | generic tool calls |
|
||||
|
||||
This table can be generated with:
|
||||
|
||||
```bash
|
||||
./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
@@ -1394,20 +1202,11 @@ curl http://localhost:8080/v1/chat/completions \
|
||||
|
||||
```shell
|
||||
# Native support:
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M
|
||||
llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L
|
||||
llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
|
||||
llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
|
||||
|
||||
# Native support for DeepSeek R1 works best w/ our own template (official template buggy)
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L \
|
||||
--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M \
|
||||
--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
|
||||
|
||||
# Native support requires the right template for these GGUFs:
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
|
||||
@@ -1437,17 +1236,17 @@ curl http://localhost:8080/v1/chat/completions \
|
||||
{
|
||||
"type":"function",
|
||||
"function":{
|
||||
"name":"python",
|
||||
"description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
|
||||
"name":"get_current_weather",
|
||||
"description":"Get the current weather in a given location",
|
||||
"parameters":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"code":{
|
||||
"location":{
|
||||
"type":"string",
|
||||
"description":"The code to run in the ipython interpreter."
|
||||
"description":"The city and state, e.g. San Francisco, CA"
|
||||
}
|
||||
},
|
||||
"required":["code"]
|
||||
"required":["location"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1455,7 +1254,7 @@ curl http://localhost:8080/v1/chat/completions \
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Print a hello world message with python."
|
||||
"content": "What is the weather like in Istanbul?."
|
||||
}
|
||||
]
|
||||
}'
|
||||
|
||||
Binary file not shown.
@@ -173,7 +173,6 @@ struct slot_params {
|
||||
{"grammar_trigger_words", grammar_trigger_words},
|
||||
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
|
||||
{"preserved_tokens", sampling.preserved_tokens},
|
||||
{"chat_format", common_chat_format_name(oaicompat_chat_format)},
|
||||
{"samplers", samplers},
|
||||
{"speculative.n_max", speculative.n_max},
|
||||
{"speculative.n_min", speculative.n_min},
|
||||
@@ -725,19 +724,9 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
msg.content = content;
|
||||
}
|
||||
|
||||
json message {
|
||||
{"role", "assistant"},
|
||||
};
|
||||
if (!msg.reasoning_content.empty()) {
|
||||
message["reasoning_content"] = msg.reasoning_content;
|
||||
}
|
||||
if (msg.content.empty() && !msg.tool_calls.empty()) {
|
||||
message["content"] = json();
|
||||
} else {
|
||||
message["content"] = msg.content;
|
||||
}
|
||||
json tool_calls;
|
||||
if (!msg.tool_calls.empty()) {
|
||||
auto tool_calls = json::array();
|
||||
tool_calls = json::array();
|
||||
for (const auto & tc : msg.tool_calls) {
|
||||
tool_calls.push_back({
|
||||
{"type", "function"},
|
||||
@@ -748,7 +737,15 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
{"id", tc.id},
|
||||
});
|
||||
}
|
||||
message["tool_calls"] = tool_calls;
|
||||
}
|
||||
|
||||
json message {
|
||||
{"content", msg.content},
|
||||
{"tool_calls", tool_calls},
|
||||
{"role", "assistant"},
|
||||
};
|
||||
if (!msg.tool_plan.empty()) {
|
||||
message["tool_plan"] = msg.tool_plan;
|
||||
}
|
||||
|
||||
json choice {
|
||||
@@ -2076,8 +2073,8 @@ struct server_context {
|
||||
|
||||
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
|
||||
// Might be better to reject the request with a 400 ?
|
||||
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict);
|
||||
slot.params.n_predict = slot.n_predict;
|
||||
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
|
||||
}
|
||||
|
||||
if (slot.params.ignore_eos && has_eos_token) {
|
||||
@@ -4063,7 +4060,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
auto body = json::parse(req.body);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates);
|
||||
|
||||
return handle_completions_impl(
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
@@ -4076,7 +4073,7 @@ int main(int argc, char ** argv) {
|
||||
// same with handle_chat_completions, but without inference part
|
||||
const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
auto body = json::parse(req.body);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates);
|
||||
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
|
||||
};
|
||||
|
||||
|
||||
@@ -92,7 +92,6 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
|
||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
@@ -156,11 +155,11 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
@@ -176,7 +175,7 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
# TODO: fix these
|
||||
# (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
# (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
@@ -215,7 +214,6 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
|
||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
@@ -275,6 +273,7 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("hf_repo,template_override", [
|
||||
("bartowski/c4ai-command-r7b-12-2024-GGUF:Q4_K_M", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
|
||||
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
@@ -299,16 +298,13 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
|
||||
|
||||
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
|
||||
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
|
||||
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
|
||||
# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
# ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None):
|
||||
global server
|
||||
n_predict = 512
|
||||
server.n_slots = 1
|
||||
@@ -327,7 +323,6 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
|
||||
{"role": "user", "content": "What is the weather in Istanbul?"},
|
||||
],
|
||||
"tools": [WEATHER_TOOL],
|
||||
@@ -337,7 +332,6 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
|
||||
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
|
||||
@@ -346,166 +340,22 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
|
||||
assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
|
||||
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
|
||||
(None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
|
||||
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
|
||||
("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
])
|
||||
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
global server
|
||||
# n_predict = 512
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192 * 2
|
||||
server.n_predict = n_predict
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things, and provide very concise answers. Do not explain your reasoning to the user. Provide any numerical values back to the user with at most two decimals."},
|
||||
{"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_6789",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calculate",
|
||||
"arguments": "{\"expression\":\"sin(30 * pi / 180)\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "calculate",
|
||||
"content": 0.55644242476,
|
||||
"tool_call_id": "call_6789"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type":"function",
|
||||
"function":{
|
||||
"name":"calculate",
|
||||
"description":"A calculator function that computes values of arithmetic expressions in the Python syntax",
|
||||
"parameters":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"expression":{
|
||||
"type":"string",
|
||||
"description":"An arithmetic expression to compute the value of (Python syntad, assuming all floats)"
|
||||
}
|
||||
},
|
||||
"required":["expression"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
|
||||
content = choice["message"].get("content")
|
||||
assert content is not None, f'Expected content in {choice["message"]}'
|
||||
if result_override is not None:
|
||||
assert re.match(result_override, content), f'Expected {result_override}, got {content}'
|
||||
else:
|
||||
assert re.match('^[\\s\\S]*?The (y[ -])?coordinate [\\s\\S]*?is (approximately )?0\\.56\\b|^0\\.56$', content), \
|
||||
f'Expected something like "The y coordinate is 0.56.", got {content}'
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
|
||||
(128, 'deepseek', "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
|
||||
(1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(1024, 'none', "<think>\n?I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
|
||||
(1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
])
|
||||
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
global server
|
||||
server.n_slots = 1
|
||||
server.reasoning_format = reasoning_format
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192 * 2
|
||||
server.n_predict = n_predict
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the sum of 102 and 7?"},
|
||||
]
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
|
||||
|
||||
content = choice["message"].get("content")
|
||||
if expect_content is None:
|
||||
assert content is None, f'Expected no content in {choice["message"]}'
|
||||
else:
|
||||
assert re.match(expect_content, content), f'Expected {expect_content}, got {content}'
|
||||
|
||||
reasoning_content = choice["message"].get("reasoning_content")
|
||||
if expect_reasoning_content is None:
|
||||
assert reasoning_content is None, f'Expected no reasoning content in {choice["message"]}'
|
||||
else:
|
||||
assert re.match(expect_reasoning_content, reasoning_content), f'Expected {expect_reasoning_content}, got {reasoning_content}'
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
|
||||
(None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
# (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
|
||||
(None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
|
||||
|
||||
('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
('{"code":"print("}', "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
(None, "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
@@ -521,13 +371,15 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
|
||||
|
||||
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
|
||||
(None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
|
||||
# (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_hello_world(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
global server
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = 512 # High because of DeepSeek R1
|
||||
server.n_predict = 128
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
@@ -554,7 +406,6 @@ def test_hello_world(expected_arguments_override: str | None, hf_repo: str, temp
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
|
||||
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
if expected_arguments_override is not None:
|
||||
|
||||
@@ -78,7 +78,6 @@ class ServerProcess:
|
||||
draft_max: int | None = None
|
||||
no_webui: bool | None = None
|
||||
jinja: bool | None = None
|
||||
reasoning_format: Literal['deepseek', 'none'] | None = None
|
||||
chat_template: str | None = None
|
||||
chat_template_file: str | None = None
|
||||
|
||||
@@ -173,8 +172,6 @@ class ServerProcess:
|
||||
server_args.append("--no-webui")
|
||||
if self.jinja:
|
||||
server_args.append("--jinja")
|
||||
if self.reasoning_format is not None:
|
||||
server_args.extend(("--reasoning-format", self.reasoning_format))
|
||||
if self.chat_template:
|
||||
server_args.extend(["--chat-template", self.chat_template])
|
||||
if self.chat_template_file:
|
||||
|
||||
@@ -578,7 +578,6 @@ static json oaicompat_completion_params_parse(const json & body) {
|
||||
static json oaicompat_completion_params_parse(
|
||||
const json & body, /* openai api json semantics */
|
||||
bool use_jinja,
|
||||
common_reasoning_format reasoning_format,
|
||||
const common_chat_templates & chat_templates)
|
||||
{
|
||||
json llama_params;
|
||||
@@ -634,10 +633,9 @@ static json oaicompat_completion_params_parse(
|
||||
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
|
||||
}
|
||||
common_chat_inputs inputs;
|
||||
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
inputs.messages = body.at("messages");
|
||||
inputs.tools = tools;
|
||||
inputs.tool_choice = tool_choice;
|
||||
inputs.messages = body.at("messages");
|
||||
inputs.tools = tools;
|
||||
inputs.tool_choice = tool_choice;
|
||||
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
|
||||
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
|
||||
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
|
||||
|
||||
@@ -254,12 +254,12 @@ export default function ChatMessage({
|
||||
🔄 Regenerate
|
||||
</button>
|
||||
)}
|
||||
<CopyButton
|
||||
className="badge btn-mini show-on-hover mr-2"
|
||||
content={msg.content}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
<CopyButton
|
||||
className="badge btn-mini show-on-hover mr-2"
|
||||
content={msg.content}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,8 +7,10 @@
|
||||
#include "ggml-cpu-impl.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-quants.h"
|
||||
#include "ggml-cpu-quants.h"
|
||||
#include "ggml-threading.h"
|
||||
#include "amx/amx.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
@@ -1289,7 +1291,7 @@ struct ggml_threadpool {
|
||||
atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
|
||||
atomic_int GGML_CACHE_ALIGN n_barrier;
|
||||
atomic_int GGML_CACHE_ALIGN n_barrier_passed;
|
||||
atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
|
||||
atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
|
||||
|
||||
// these are atomic as an annotation for thread-sanitizer
|
||||
atomic_bool stop; // Used for stopping the threadpool altogether
|
||||
@@ -7488,7 +7490,6 @@ UseGgmlGemm1:;
|
||||
if (src1->type != vec_dot_type) {
|
||||
char * wdata = params->wdata;
|
||||
|
||||
const size_t nbw0 = ggml_type_size(vec_dot_type);
|
||||
const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
|
||||
const size_t nbw2 = nbw1*ne11;
|
||||
const size_t nbw3 = nbw2*ne12;
|
||||
@@ -7496,7 +7497,6 @@ UseGgmlGemm1:;
|
||||
assert(params->wsize >= ne13*nbw3);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
#if 0
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
|
||||
@@ -7506,20 +7506,6 @@ UseGgmlGemm1:;
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||
size_t bs = ggml_blck_size(vec_dot_type);
|
||||
int64_t ne10_block_start = (ith * ne10/bs) / nth;
|
||||
int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
|
||||
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
|
||||
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
|
||||
(ne10_block_end - ne10_block_start) * bs);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
if (ith == 0) {
|
||||
@@ -7607,6 +7593,7 @@ UseGgmlGemm2:;
|
||||
if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
|
||||
num_rows_per_vec_dot = 1;
|
||||
}
|
||||
|
||||
ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
|
||||
|
||||
if (nth >= nchunk0 * nchunk1) {
|
||||
@@ -7619,84 +7606,6 @@ UseGgmlGemm2:;
|
||||
|
||||
// ggml_compute_forward_mul_mat_id
|
||||
|
||||
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]
|
||||
|
||||
struct mmid_row_mapping {
|
||||
int32_t i1;
|
||||
int32_t i2;
|
||||
};
|
||||
|
||||
static void ggml_compute_forward_mul_mat_id_one_chunk(
|
||||
struct ggml_tensor * dst,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
const struct ggml_tensor * ids,
|
||||
const int64_t cur_a,
|
||||
const int64_t ir0_start,
|
||||
const int64_t ir0_end,
|
||||
const int64_t ir1_start,
|
||||
const int64_t ir1_end,
|
||||
const char * src0_cur,
|
||||
const struct mmid_row_mapping * matrix_rows,
|
||||
const size_t row_size,
|
||||
const bool src1_cont,
|
||||
const void * wdata) {
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const enum ggml_type type = src0->type;
|
||||
|
||||
ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
|
||||
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
|
||||
|
||||
const int64_t blck_0 = 16;
|
||||
const int64_t blck_1 = 16;
|
||||
|
||||
float tmp[16];
|
||||
|
||||
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
||||
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {
|
||||
const int64_t _i12 = ir1; // logical row index for this expert
|
||||
|
||||
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
|
||||
const int id = row_mapping.i1; // selected expert index
|
||||
|
||||
const int64_t i11 = id % ne11;
|
||||
const int64_t i12 = row_mapping.i2; // row index in src1
|
||||
|
||||
const int64_t i1 = id; // selected expert index
|
||||
const int64_t i2 = i12; // row
|
||||
|
||||
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
|
||||
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
|
||||
// the original src1 data pointer, so we should index using the indices directly
|
||||
// TODO: this is a bit of a hack, we should probably have a better way to handle this
|
||||
const char * src1_col = (const char *) wdata +
|
||||
(src1_cont || src1->type != vec_dot_type
|
||||
? (i11 + i12*ne11)*row_size
|
||||
: (i11*nb11 + i12*nb12));
|
||||
|
||||
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
|
||||
|
||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
|
||||
vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
|
||||
}
|
||||
|
||||
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
|
||||
|
||||
void * ptr = *p;
|
||||
ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
|
||||
*p = (void *) ((char *) ptr + size);
|
||||
return ptr;
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_mul_mat_id(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
@@ -7714,6 +7623,7 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
|
||||
const bool src1_cont = ggml_is_contiguous(src1);
|
||||
|
||||
ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
|
||||
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
|
||||
ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
|
||||
|
||||
@@ -7731,27 +7641,21 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
const int n_ids = ids->ne[0]; // n_expert_used
|
||||
const int n_as = ne02; // n_expert
|
||||
|
||||
void * wdata_cur = params->wdata;
|
||||
char * wdata_src1_end = (src1->type == vec_dot_type) ?
|
||||
(char *) params->wdata :
|
||||
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
|
||||
|
||||
if (src1->type != vec_dot_type) {
|
||||
incr_ptr_aligned(&wdata_cur, ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
|
||||
}
|
||||
struct mmid_row_mapping {
|
||||
int32_t i1;
|
||||
int32_t i2;
|
||||
};
|
||||
|
||||
int64_t * matrix_row_counts = // [n_as]
|
||||
incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t));
|
||||
|
||||
struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]]
|
||||
incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t));
|
||||
|
||||
char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as]
|
||||
incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE);
|
||||
|
||||
GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata));
|
||||
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
||||
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
|
||||
|
||||
if (src1->type != vec_dot_type) {
|
||||
char * wdata = params->wdata;
|
||||
|
||||
const size_t nbw0 = ggml_type_size(vec_dot_type);
|
||||
const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
|
||||
const size_t nbw2 = nbw1*ne11;
|
||||
const size_t nbw3 = nbw2*ne12;
|
||||
@@ -7759,32 +7663,19 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
assert(params->wsize >= ne13*nbw3);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
#if 0
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = ith; i12 < ne12; i12 += nth) {
|
||||
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
|
||||
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
|
||||
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
|
||||
ne10);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||
size_t bs = ggml_blck_size(vec_dot_type);
|
||||
int64_t ne10_block_start = (ith * ne10/bs) / nth;
|
||||
int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
|
||||
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
|
||||
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
|
||||
(ne10_block_end - ne10_block_start) * bs);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
|
||||
|
||||
if (ith == 0) {
|
||||
// initialize matrix_row_counts
|
||||
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
||||
@@ -7802,14 +7693,9 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
}
|
||||
}
|
||||
|
||||
// reset current_chunk
|
||||
for (int cur_a = ith; cur_a < n_as; cur_a += nth) {
|
||||
atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
|
||||
*current_chunk_ctr = nth;
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
// compute each matrix multiplication in sequence
|
||||
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
|
||||
const int64_t cne1 = matrix_row_counts[cur_a];
|
||||
|
||||
@@ -7817,64 +7703,84 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
continue;
|
||||
}
|
||||
|
||||
const char * src0_cur = (const char *) src0->data + cur_a * nb02;
|
||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const char * src0_cur = (const char *) src0->data + cur_a*nb02;
|
||||
|
||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
|
||||
const int64_t nr0 = ne01;
|
||||
const int64_t nr1 = cne1;
|
||||
const int64_t nr0 = ne01; // src0 rows
|
||||
const int64_t nr1 = cne1; // src1 rows
|
||||
|
||||
int chunk_size = 16;
|
||||
if (nr0 == 1 || nr1 == 1) {
|
||||
chunk_size = 64;
|
||||
}
|
||||
// distribute the thread work across the inner or outer loop based on which one is larger
|
||||
|
||||
#if defined(__aarch64__)
|
||||
// disable for ARM
|
||||
const bool disable_chunking = true;
|
||||
#else
|
||||
// disable for NUMA
|
||||
const bool disable_chunking = ggml_is_numa();
|
||||
#endif // defined(__aarch64__)
|
||||
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
|
||||
const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
|
||||
|
||||
int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
|
||||
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
|
||||
const int64_t ith0 = ith % nth0;
|
||||
const int64_t ith1 = ith / nth0;
|
||||
|
||||
if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
|
||||
nchunk0 = nr0 > nr1 ? nth : 1;
|
||||
nchunk1 = nr0 > nr1 ? 1 : nth;
|
||||
}
|
||||
const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
|
||||
const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
|
||||
|
||||
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
|
||||
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
|
||||
const int64_t ir010 = dr0*ith0;
|
||||
const int64_t ir011 = MIN(ir010 + dr0, nr0);
|
||||
|
||||
int current_chunk = ith;
|
||||
const int64_t ir110 = dr1*ith1;
|
||||
const int64_t ir111 = MIN(ir110 + dr1, nr1);
|
||||
|
||||
atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
|
||||
// threads with no work simply yield (not sure if it helps)
|
||||
//if (ir010 >= ir011 || ir110 >= ir111) {
|
||||
// sched_yield();
|
||||
// continue;
|
||||
//}
|
||||
|
||||
while (current_chunk < nchunk0 * nchunk1) {
|
||||
const int64_t ith0 = current_chunk % nchunk0;
|
||||
const int64_t ith1 = current_chunk / nchunk0;
|
||||
// block-tiling attempt
|
||||
const int64_t blck_0 = 16;
|
||||
const int64_t blck_1 = 16;
|
||||
|
||||
const int64_t ir0_start = dr0 * ith0;
|
||||
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
|
||||
// attempt to reduce false-sharing (does not seem to make a difference)
|
||||
float tmp[16];
|
||||
|
||||
const int64_t ir1_start = dr1 * ith1;
|
||||
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
||||
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
|
||||
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
|
||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
|
||||
const int64_t _i12 = ir1; // logical row index for this expert
|
||||
|
||||
ggml_compute_forward_mul_mat_id_one_chunk(
|
||||
dst, src0, src1, ids, cur_a,
|
||||
ir0_start, ir0_end, ir1_start, ir1_end,
|
||||
src0_cur, matrix_rows, row_size, src1_cont, wdata
|
||||
);
|
||||
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
|
||||
const int id = row_mapping.i1; // selected expert index
|
||||
|
||||
if (nth >= nchunk0 * nchunk1) {
|
||||
break;
|
||||
const int64_t i11 = id % ne11;
|
||||
const int64_t i12 = row_mapping.i2; // row index in src1
|
||||
|
||||
const int64_t i1 = id; // selected expert index
|
||||
const int64_t i2 = i12; // row
|
||||
|
||||
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
|
||||
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
|
||||
// the original src1 data pointer, so we should index using the indices directly
|
||||
// TODO: this is a bit of a hack, we should probably have a better way to handle this
|
||||
const char * src1_col = (const char *) wdata +
|
||||
(src1_cont || src1->type != vec_dot_type
|
||||
? (i11 + i12*ne11)*row_size
|
||||
: (i11*nb11 + i12*nb12));
|
||||
|
||||
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
|
||||
|
||||
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
|
||||
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
|
||||
//}
|
||||
|
||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
|
||||
vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
|
||||
}
|
||||
|
||||
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
#undef MMID_MATRIX_ROW
|
||||
}
|
||||
|
||||
// ggml_compute_forward_out_prod
|
||||
@@ -13807,19 +13713,14 @@ struct ggml_cplan ggml_graph_plan(
|
||||
cur = 0;
|
||||
const struct ggml_tensor * src0 = node->src[0];
|
||||
const struct ggml_tensor * src1 = node->src[1];
|
||||
const struct ggml_tensor * ids = node->src[2];
|
||||
const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
|
||||
const int n_as = src0->ne[2];
|
||||
// src1
|
||||
if (src1->type != vec_dot_type) {
|
||||
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)) + sizeof(int64_t);
|
||||
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
|
||||
}
|
||||
// matrix_row_counts
|
||||
cur += n_as * sizeof(int64_t) + sizeof(int64_t);
|
||||
// matrix_rows
|
||||
cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t);
|
||||
// atomic_current_chunk
|
||||
cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE;
|
||||
const int n_as = src0->ne[2];
|
||||
cur += GGML_PAD(cur, sizeof(int64_t)); // align
|
||||
cur += n_as * sizeof(int64_t); // matrix_row_counts
|
||||
cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
|
||||
} break;
|
||||
case GGML_OP_OUT_PROD:
|
||||
{
|
||||
|
||||
@@ -280,6 +280,14 @@ template <> inline __m256bh load(const float *p) {
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
||||
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
||||
static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// FLOATING POINT MATRIX MULTIPLICATION
|
||||
|
||||
@@ -606,14 +614,6 @@ class tinyBLAS_Q0_AVX {
|
||||
TC *C, int64_t ldc,
|
||||
int ith, int nth)
|
||||
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||
const int8_t kvalues_iq4nl[16] = {
|
||||
-127, -104, -83, -65,
|
||||
-49, -35, -22, -10,
|
||||
1, 13, 25, 38,
|
||||
53, 69, 89, 113
|
||||
};
|
||||
|
||||
iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
|
||||
}
|
||||
|
||||
void matmul(int64_t m, int64_t n) {
|
||||
@@ -1038,7 +1038,6 @@ class tinyBLAS_Q0_AVX {
|
||||
const int64_t ldc;
|
||||
const int ith;
|
||||
const int nth;
|
||||
__m128i iq4nlt;
|
||||
};
|
||||
#endif // __AVX__
|
||||
|
||||
|
||||
@@ -15,9 +15,9 @@ if (CUDAToolkit_FOUND)
|
||||
if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
|
||||
set(CMAKE_CUDA_ARCHITECTURES "native")
|
||||
elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
|
||||
set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80")
|
||||
set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
|
||||
else()
|
||||
set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75;80")
|
||||
set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75")
|
||||
endif()
|
||||
endif()
|
||||
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
||||
|
||||
@@ -178,11 +178,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
int major_version = 0;
|
||||
size_t version_length = 0;
|
||||
if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
|
||||
std::vector<char> version(version_length+1, '\0');
|
||||
std::string version(version_length, '\0');
|
||||
if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
|
||||
version.resize(::strlen(version.data()));
|
||||
version.resize(::strlen(version.c_str()));
|
||||
int parsed_value = 0;
|
||||
if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) {
|
||||
if (std::from_chars(version.c_str(), version.c_str() + version.length(), parsed_value).ec == std::errc()) {
|
||||
major_version = parsed_value;
|
||||
}
|
||||
}
|
||||
@@ -1480,7 +1480,12 @@ static void ggml_cuda_op_mul_mat(
|
||||
const size_t nbytes_data = ggml_nbytes(src0);
|
||||
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
|
||||
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
|
||||
// TODO: remove this for MUSA once the Guilty Lockup issue is resolved
|
||||
#ifndef GGML_USE_MUSA
|
||||
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
|
||||
#else // GGML_USE_MUSA
|
||||
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
|
||||
#endif // !GGML_USE_MUSA
|
||||
}
|
||||
|
||||
// If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
|
||||
|
||||
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
@@ -143,7 +143,6 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_rms_norm;
|
||||
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
|
||||
cl_kernel kernel_soft_max, kernel_soft_max_4;
|
||||
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
|
||||
cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
|
||||
cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
|
||||
cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
|
||||
@@ -615,8 +614,6 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf_8", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program, "kernel_soft_max", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f32", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err));
|
||||
@@ -1047,16 +1044,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
return true;
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
return op->ne[3] == 1;
|
||||
case GGML_OP_ROPE: {
|
||||
const int mode = ((const int32_t *) op->op_params)[2];
|
||||
if (mode & GGML_ROPE_TYPE_MROPE) {
|
||||
return false;
|
||||
}
|
||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
return false;
|
||||
}
|
||||
case GGML_OP_ROPE:
|
||||
return true;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -3677,8 +3666,6 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||
|
||||
// Local size must be wave size. Each workgroup is a wave, working on a row,
|
||||
// where a row corresponds to leading dimension.
|
||||
int nth = MIN(32, ne00);
|
||||
@@ -3696,17 +3683,9 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
|
||||
cl_kernel kernel;
|
||||
|
||||
if (ne00%4 == 0) {
|
||||
if (use_f16) {
|
||||
kernel = backend_ctx->kernel_soft_max_4_f16;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_soft_max_4;
|
||||
}
|
||||
kernel = backend_ctx->kernel_soft_max_4;
|
||||
} else {
|
||||
if (use_f16) {
|
||||
kernel = backend_ctx->kernel_soft_max_f16;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_soft_max;
|
||||
}
|
||||
kernel = backend_ctx->kernel_soft_max;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
@@ -3787,8 +3766,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
const int nb2 = dst ? dst->nb[2] : 0;
|
||||
const int nb3 = dst ? dst->nb[3] : 0;
|
||||
|
||||
GGML_ASSERT(ne10 % ne02 == 0);
|
||||
GGML_ASSERT(ne10 >= ne02);
|
||||
GGML_ASSERT(ne10 == ne02);
|
||||
|
||||
int nth = MIN(64, ne00);
|
||||
|
||||
|
||||
@@ -679,9 +679,6 @@ kernel void kernel_diag_mask_inf_8(
|
||||
//------------------------------------------------------------------------------
|
||||
// softmax
|
||||
//------------------------------------------------------------------------------
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_soft_max(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
@@ -814,141 +811,6 @@ kernel void kernel_soft_max_4(
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_soft_max_f16(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
global half * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
float scale,
|
||||
float max_bias,
|
||||
float m0,
|
||||
float m1,
|
||||
int n_head_log2
|
||||
) {
|
||||
src0 = (global float *)((global char *)src0 + offset0);
|
||||
src1 = (global half *)((global char *)src1 + offset1);
|
||||
dst = (global float *)((global char *)dst + offsetd);
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0;
|
||||
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
int h = i02;
|
||||
|
||||
float base = h < n_head_log2 ? m0 : m1;
|
||||
int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float lmax = -INFINITY;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
||||
}
|
||||
float max = sub_group_reduce_max(lmax);
|
||||
|
||||
// parallel sum
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);
|
||||
lsum += exp_psrc0;
|
||||
// Remember the result of exp here. exp is expensive, so we really do not
|
||||
// wish to compute it twice.
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
|
||||
const float sum = sub_group_reduce_add(lsum);
|
||||
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
pdst[i00] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_soft_max_4_f16(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
global half * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
float scale,
|
||||
float max_bias,
|
||||
float m0,
|
||||
float m1,
|
||||
int n_head_log2
|
||||
) {
|
||||
src0 = (global float *)((global char *)src0 + offset0);
|
||||
src1 = (global half *)((global char *)src1 + offset1);
|
||||
dst = (global float *)((global char *)dst + offsetd);
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0;
|
||||
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
int h = i02;
|
||||
|
||||
float base = h < n_head_log2 ? m0 : m1;
|
||||
int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float4 lmax4 = -INFINITY;
|
||||
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
|
||||
}
|
||||
float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3));
|
||||
|
||||
const float max = sub_group_reduce_max(lmax);
|
||||
|
||||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
|
||||
|
||||
const float sum = sub_group_reduce_add(lsum);
|
||||
|
||||
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
||||
pdst4[i00] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_rope
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
@@ -1385,10 +1385,6 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
||||
|
||||
uint32_t lut_size = 0;
|
||||
switch (src0_type) {
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
lut_size = 2*2048;
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
lut_size = 8*256;
|
||||
break;
|
||||
@@ -1434,7 +1430,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
|
||||
|
||||
// some shaders have a minimum subgroup size
|
||||
const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
|
||||
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
|
||||
const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
|
||||
|
||||
@@ -1497,13 +1492,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
|
||||
const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
|
||||
|
||||
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
|
||||
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
|
||||
l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
|
||||
m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
|
||||
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
|
||||
|
||||
l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
|
||||
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
|
||||
l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
|
||||
m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
|
||||
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
|
||||
|
||||
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
|
||||
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
|
||||
@@ -1627,8 +1622,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
//CREATE_FA(GGML_TYPE_Q4_K, q4_k)
|
||||
//CREATE_FA(GGML_TYPE_Q5_K, q5_k)
|
||||
//CREATE_FA(GGML_TYPE_Q6_K, q6_k)
|
||||
//CREATE_FA(GGML_TYPE_IQ1_S, iq1_s)
|
||||
//CREATE_FA(GGML_TYPE_IQ1_M, iq1_m)
|
||||
//CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs)
|
||||
//CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs)
|
||||
//CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
|
||||
@@ -1663,8 +1656,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
@@ -1684,8 +1675,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
@@ -1740,8 +1729,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
@@ -1761,8 +1748,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
@@ -1788,15 +1773,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
} else {
|
||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
@@ -1809,15 +1792,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
}
|
||||
#undef CREATE_MM2
|
||||
#undef CREATE_MM
|
||||
@@ -1860,8 +1841,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
@@ -1885,8 +1864,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
@@ -1928,15 +1905,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
|
||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||
@@ -1953,15 +1928,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
#undef CREATE_MM
|
||||
}
|
||||
|
||||
@@ -1991,8 +1964,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
@@ -2013,8 +1984,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
@@ -2036,8 +2005,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
@@ -2058,8 +2025,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S], "dequant_iq1_s", dequant_iq1_s_len, dequant_iq1_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M], "dequant_iq1_m", dequant_iq1_m_len, dequant_iq1_m_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
@@ -2076,8 +2041,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
@@ -2093,8 +2056,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
@@ -3047,8 +3008,6 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
@@ -3103,8 +3062,6 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
@@ -3142,8 +3099,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
@@ -3193,8 +3148,6 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
@@ -3227,8 +3180,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
@@ -8105,8 +8056,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
@@ -8181,8 +8130,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
//case GGML_TYPE_Q4_K:
|
||||
//case GGML_TYPE_Q5_K:
|
||||
//case GGML_TYPE_Q6_K:
|
||||
//case GGML_TYPE_IQ1_S:
|
||||
//case GGML_TYPE_IQ1_M:
|
||||
//case GGML_TYPE_IQ2_XXS:
|
||||
//case GGML_TYPE_IQ2_XS:
|
||||
//case GGML_TYPE_IQ2_S:
|
||||
@@ -8206,8 +8153,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
|
||||
@@ -12,7 +12,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
if (gl_LocalInvocationIndex.x != 0) {
|
||||
return;
|
||||
|
||||
@@ -217,7 +217,7 @@ void quantize(uint dst_idx, uint src_idx)
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
if (gl_LocalInvocationIndex.x != 0) {
|
||||
return;
|
||||
|
||||
@@ -88,83 +88,6 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ1_S)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib32 = iqs / 32;
|
||||
const uint ib8 = iqs / 8;
|
||||
const int i8 = int(iqs % 8);
|
||||
const uint qh = data_a[a_offset + ib].qh[ib32];
|
||||
const uint qs = data_a[a_offset + ib].qs[ib8];
|
||||
const float dl = float(2 * bitfieldExtract(qh, 12, 3) + 1);
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const uint idxhi = bitfieldExtract(qh, 3 * int(ib8 & 3), 3);
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]);
|
||||
// Signed bitfield extract.
|
||||
const ivec2 gvec = ivec2(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
||||
);
|
||||
return dl * (vec2(gvec) + delta);
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib32 = iqs / 32;
|
||||
const uint ib8 = iqs / 8;
|
||||
const int i8 = int(iqs % 8);
|
||||
const uint qh = data_a[a_offset + ib].qh[ib32];
|
||||
const uint qs = data_a[a_offset + ib].qs[ib8];
|
||||
const float dl = 2 * bitfieldExtract(qh, 12, 3) + 1;
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
|
||||
// Signed bitfield extract.
|
||||
const ivec4 gvec = ivec4(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 2), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 3), 2)
|
||||
);
|
||||
return dl * (vec4(gvec) + delta);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ1_M)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib8 = iqs / 8;
|
||||
const uint ib16 = iqs / 16;
|
||||
const int i8 = int(iqs % 8);
|
||||
const uint sc = data_a[a_offset + ib].scales[iqs / 64];
|
||||
const uint qs = data_a[a_offset + ib].qs[ib8];
|
||||
const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));
|
||||
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
// Signed bitfield extract.
|
||||
const ivec2 gvec = ivec2(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
||||
);
|
||||
return dl * (vec2(gvec) + delta);
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib8 = iqs / 8;
|
||||
const uint ib16 = iqs / 16;
|
||||
const int i8 = int(iqs % 8);
|
||||
const uint sc = data_a[a_offset + ib].scales[iqs / 64];
|
||||
const uint qs = data_a[a_offset + ib].qs[ib8];
|
||||
const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));
|
||||
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
// Signed bitfield extract.
|
||||
const ivec4 gvec = ivec4(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 2), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 3), 2)
|
||||
);
|
||||
return dl * (vec4(gvec) + delta);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_XXS)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib32 = iqs / 32;
|
||||
@@ -434,16 +357,7 @@ vec2 get_dm(uint ib, uint a_offset) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ1_M)
|
||||
vec2 get_dm(uint ib, uint a_offset) {
|
||||
const uint16_t[4] scales = data_a[a_offset + ib].scales;
|
||||
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
||||
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
|
||||
return vec2(d, 0);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
vec2 get_dm(uint ib, uint a_offset) {
|
||||
return vec2(float(data_a[a_offset + ib].d), 0);
|
||||
}
|
||||
|
||||
@@ -301,56 +301,6 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2
|
||||
return ret;
|
||||
}
|
||||
|
||||
#if defined(DATA_A_IQ1_S)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S {
|
||||
block_iq1_s block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint ib32 = idx / 32;
|
||||
const uint ib8 = idx / 8;
|
||||
|
||||
const uint qh = bl.block.qh[ib32];
|
||||
const uint qs = bl.block.qs[ib8];
|
||||
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)];
|
||||
|
||||
float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta));
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ1_M)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M {
|
||||
block_iq1_m block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const u16vec4 scales = u16vec4(bl.block.scales[0], bl.block.scales[1], bl.block.scales[2], bl.block.scales[3]) >> 12;
|
||||
const float16_t d = uint16BitsToHalf(scales.x | (scales.y << 4) | (scales.z << 8) | (scales.w << 12));
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint ib8 = idx / 8;
|
||||
const uint ib16 = idx / 16;
|
||||
const int i8 = int(idx % 8);
|
||||
const uint sc = bl.block.scales[ib8 / 8];
|
||||
const uint qs = bl.block.qs[ib8];
|
||||
const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1));
|
||||
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const uint grid = iq1s_grid[qs | ((qh & 7) << 8)];
|
||||
|
||||
float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta));
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_XXS)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS {
|
||||
block_iq2_xxs block;
|
||||
@@ -562,10 +512,6 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
|
||||
#define dequantFuncA dequantFuncQ5_K
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
#define dequantFuncA dequantFuncQ6_K
|
||||
#elif defined(DATA_A_IQ1_S)
|
||||
#define dequantFuncA dequantFuncIQ1_S
|
||||
#elif defined(DATA_A_IQ1_M)
|
||||
#define dequantFuncA dequantFuncIQ1_M
|
||||
#elif defined(DATA_A_IQ2_XXS)
|
||||
#define dequantFuncA dequantFuncIQ2_XXS
|
||||
#elif defined(DATA_A_IQ2_XS)
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
|
||||
#include "dequant_head.comp"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_iq1_m data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
// Each thread handles 1 subblock (32 values with 2 scales)
|
||||
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
if (ib >= p.nel / 256) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint ib32 = gl_LocalInvocationID.x % 8;
|
||||
const uint ib64 = ib32 / 2;
|
||||
const uint b_idx = 256 * ib + 32 * ib32;
|
||||
|
||||
const uint16_t[4] scales = data_a[ib].scales;
|
||||
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
||||
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
|
||||
|
||||
const uint sc = data_a[ib].scales[ib64];
|
||||
[[unroll]] for (int l = 0; l < 4; ++l) {
|
||||
const uint ib16 = 2 * ib32 + l / 2;
|
||||
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
|
||||
const uint qh = data_a[ib].qh[ib16] >> (4 * (l & 1));
|
||||
const uint qs = data_a[ib].qs[4 * ib32 + l];
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
[[unroll]] for (int j = 0; j < 8; ++j) {
|
||||
data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "dequant_head.comp"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_iq1_s data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
// Each thread handles 1 subblock (32 values with 2 scales)
|
||||
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
if (ib >= p.nel / 256) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint ib32 = gl_LocalInvocationID.x % 8;
|
||||
const uint b_idx = 256 * ib + 32 * ib32;
|
||||
|
||||
uint qh = data_a[ib].qh[ib32];
|
||||
const float d = float(data_a[ib].d);
|
||||
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||
const uint qs = data_a[ib].qs[4 * ib32 + l];
|
||||
const uint hi = bitfieldExtract(qh, 3 * int(l), 3);
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (hi << 8)]);
|
||||
[[unroll]] for (int j = 0; j < 8; ++j) {
|
||||
data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -104,7 +104,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ void main() {
|
||||
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
|
||||
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
|
||||
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
|
||||
@@ -133,7 +133,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
void main() {
|
||||
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
|
||||
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
#version 450
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#include "mul_mat_vec_base.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
|
||||
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
|
||||
const uint y_idx = i * QUANT_K + 32 * ib32;
|
||||
|
||||
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint16_t[4] scales = data_a[ibi].scales;
|
||||
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
||||
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
|
||||
|
||||
const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1));
|
||||
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||
const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1));
|
||||
const uint qs = data_a[ibi].qs[4 * ib32 + l];
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1);
|
||||
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
|
||||
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
||||
[[unroll]] for (int k = 0; k < 4; ++k) {
|
||||
sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta,
|
||||
fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum));
|
||||
}
|
||||
temp[j][n] = fma(dl, sum, temp[j][n]);
|
||||
}
|
||||
}
|
||||
ibi += num_blocks_per_row;
|
||||
}
|
||||
}
|
||||
|
||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
uint a_offset, b_offset, d_offset;
|
||||
get_offsets(a_offset, b_offset, d_offset);
|
||||
|
||||
const uint num_blocks_per_row = p.ncols / QUANT_K;
|
||||
|
||||
// 8 threads are used to process each block
|
||||
const uint blocks_per_wg = gl_WorkGroupSize.x/8;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint itid = tid % 8; // 0...7
|
||||
const uint ix = tid / 8;
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||
temp[j][i] = FLOAT_TYPE(0);
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
|
||||
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
|
||||
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
// do NUM_ROWS at a time, unless there aren't enough remaining rows
|
||||
if (first_row + NUM_ROWS <= p.stride_d) {
|
||||
compute_outputs(first_row, NUM_ROWS);
|
||||
} else {
|
||||
if (first_row >= p.stride_d) {
|
||||
return;
|
||||
}
|
||||
compute_outputs(first_row, p.stride_d - first_row);
|
||||
}
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
#version 450
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#include "mul_mat_vec_base.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
|
||||
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
|
||||
const uint y_idx = i * QUANT_K + 32 * ib32;
|
||||
|
||||
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const float d = float(data_a[ibi].d);
|
||||
const uint qh = data_a[ibi].qh[ib32];
|
||||
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
|
||||
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||
const uint qs = data_a[ibi].qs[4 * ib32 + l];
|
||||
const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3);
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]);
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
|
||||
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
||||
[[unroll]] for (int k = 0; k < 4; ++k) {
|
||||
sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta,
|
||||
fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum));
|
||||
}
|
||||
temp[j][n] = fma(dl, sum, temp[j][n]);
|
||||
}
|
||||
}
|
||||
ibi += num_blocks_per_row;
|
||||
}
|
||||
}
|
||||
|
||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
uint a_offset, b_offset, d_offset;
|
||||
get_offsets(a_offset, b_offset, d_offset);
|
||||
|
||||
const uint num_blocks_per_row = p.ncols / QUANT_K;
|
||||
|
||||
// 8 threads are used to process each block
|
||||
const uint blocks_per_wg = gl_WorkGroupSize.x/8;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint itid = tid % 8; // 0...7
|
||||
const uint ix = tid / 8;
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||
temp[j][i] = FLOAT_TYPE(0);
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
|
||||
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
|
||||
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
// do NUM_ROWS at a time, unless there aren't enough remaining rows
|
||||
if (first_row + NUM_ROWS <= p.stride_d) {
|
||||
compute_outputs(first_row, NUM_ROWS);
|
||||
} else {
|
||||
if (first_row >= p.stride_d) {
|
||||
return;
|
||||
}
|
||||
compute_outputs(first_row, p.stride_d - first_row);
|
||||
}
|
||||
}
|
||||
@@ -6,9 +6,6 @@
|
||||
#ifdef FLOAT16
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#endif
|
||||
#if defined(DATA_A_IQ1_M)
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
@@ -98,7 +95,7 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
@@ -440,56 +437,6 @@ void main() {
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
|
||||
#elif defined(DATA_A_IQ1_S)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
const uint ib8 = (idx % 128) / 4;
|
||||
const int i8 = 2 * int(idx % 4);
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qh = data_a[ib].qh[ib32];
|
||||
const uint qs = data_a[ib].qs[ib8];
|
||||
const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
|
||||
|
||||
const ivec2 gvec = ivec2(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
||||
);
|
||||
const vec2 v = dl * (vec2(gvec) + delta);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_IQ1_M)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib8 = (idx % 128) / 4;
|
||||
const uint ib16 = ib8 / 2;
|
||||
const int i8 = 2 * int(idx % 4);
|
||||
|
||||
const uint16_t[4] scales = data_a[ib].scales;
|
||||
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
||||
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
|
||||
const uint sc = scales[ib8 / 8];
|
||||
const uint qs = data_a[ib].qs[ib8];
|
||||
const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1));
|
||||
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
const ivec2 gvec = ivec2(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
||||
);
|
||||
const vec2 v = dl * (vec2(gvec) + delta);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_IQ2_XXS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
@@ -106,7 +106,7 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
|
||||
@@ -294,187 +294,6 @@ struct block_q6_K_packed16
|
||||
|
||||
// IQuants
|
||||
|
||||
#define QUANT_K_IQ1_S 256
|
||||
#define QUANT_R_IQ1_S 1
|
||||
|
||||
struct block_iq1_s {
|
||||
float16_t d;
|
||||
uint8_t qs[QUANT_K_IQ1_S/8];
|
||||
uint16_t qh[QUANT_K_IQ1_S/32];
|
||||
};
|
||||
|
||||
#define QUANT_K_IQ1_M 256
|
||||
#define QUANT_R_IQ1_M 1
|
||||
|
||||
struct block_iq1_m {
|
||||
uint8_t qs[QUANT_K_IQ1_M/8];
|
||||
uint8_t qh[QUANT_K_IQ1_M/16];
|
||||
uint16_t scales[QUANT_K_IQ1_M/64];
|
||||
};
|
||||
|
||||
#if defined(DATA_A_IQ1_S)
|
||||
#define QUANT_K QUANT_K_IQ1_S
|
||||
#define QUANT_R QUANT_R_IQ1_S
|
||||
#define A_TYPE block_iq1_s
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ1_M)
|
||||
#define QUANT_K QUANT_K_IQ1_M
|
||||
#define QUANT_R QUANT_R_IQ1_M
|
||||
#define A_TYPE block_iq1_m
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)
|
||||
#define IQ1S_DELTA 0.125f
|
||||
#define IQ1M_DELTA 0.125f
|
||||
|
||||
// Packed IQ1S grid where every 2 vec8 are encoded on 32 bits (2 bits per coordinate).
|
||||
const uint[1024] iq1s_grid_const = {
|
||||
0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01,
|
||||
0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4,
|
||||
0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41,
|
||||
0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f,
|
||||
0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 0xf33cf3d0, 0xf300f334,
|
||||
0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f,
|
||||
0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040,
|
||||
0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f,
|
||||
0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5,
|
||||
0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3,
|
||||
0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff,
|
||||
0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570,
|
||||
0xf540f575, 0xf55df55f, 0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f,
|
||||
0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf,
|
||||
0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f,
|
||||
0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07,
|
||||
0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc,
|
||||
0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374,
|
||||
0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0,
|
||||
0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001,
|
||||
0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043,
|
||||
0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc,
|
||||
0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117,
|
||||
0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f,
|
||||
0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 0xc743c74c, 0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5,
|
||||
0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474,
|
||||
0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d,
|
||||
0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd,
|
||||
0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50,
|
||||
0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10,
|
||||
0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30,
|
||||
0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1,
|
||||
0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c,
|
||||
0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074,
|
||||
0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134,
|
||||
0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 0xd7d5d7d7,
|
||||
0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3,
|
||||
0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450,
|
||||
0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577,
|
||||
0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c,
|
||||
0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5,
|
||||
0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c,
|
||||
0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00,
|
||||
0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300,
|
||||
0x331d331c, 0x33173310, 0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc,
|
||||
0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034,
|
||||
0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077,
|
||||
0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5,
|
||||
0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117,
|
||||
0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f,
|
||||
0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5,
|
||||
0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 0x34013400, 0x341f3404,
|
||||
0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1,
|
||||
0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd,
|
||||
0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71,
|
||||
0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7,
|
||||
0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00,
|
||||
0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44,
|
||||
0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00,
|
||||
0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0,
|
||||
0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 0x03000303,
|
||||
0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343,
|
||||
0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd,
|
||||
0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031,
|
||||
0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011,
|
||||
0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c,
|
||||
0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4,
|
||||
0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c,
|
||||
0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174,
|
||||
0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7,
|
||||
0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d,
|
||||
0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4,
|
||||
0x0433043c, 0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c,
|
||||
0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7,
|
||||
0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510,
|
||||
0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33,
|
||||
0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4,
|
||||
0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73,
|
||||
0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f,
|
||||
0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337,
|
||||
0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343,
|
||||
0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030,
|
||||
0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075,
|
||||
0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 0x11d711d1, 0x113f11d4,
|
||||
0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170,
|
||||
0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705,
|
||||
0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c,
|
||||
0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c,
|
||||
0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514,
|
||||
0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c,
|
||||
0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3,
|
||||
0x7cd07cc4, 0x7c337c3c, 0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70,
|
||||
0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03,
|
||||
0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c,
|
||||
0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c,
|
||||
0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074,
|
||||
0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104,
|
||||
0x7111711c, 0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7,
|
||||
0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757,
|
||||
0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c,
|
||||
0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c,
|
||||
0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4,
|
||||
0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc,
|
||||
0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03,
|
||||
0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc,
|
||||
0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54,
|
||||
0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f,
|
||||
0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf,
|
||||
0x40c040c3, 0x40c440c1, 0x40d040dc, 0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c,
|
||||
0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c,
|
||||
0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4,
|
||||
0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174,
|
||||
0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700,
|
||||
0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7,
|
||||
0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d,
|
||||
0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531,
|
||||
0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf,
|
||||
0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57,
|
||||
0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13,
|
||||
0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01,
|
||||
0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f,
|
||||
0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7,
|
||||
0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074,
|
||||
0x50475040, 0x51cc51f0, 0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107,
|
||||
0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd,
|
||||
0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0,
|
||||
0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7,
|
||||
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
|
||||
};
|
||||
|
||||
shared uint16_t iq1s_grid[2048];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
for (uint i = gl_LocalInvocationIndex.x; i < iq1s_grid_const.length(); i += wgsize.x) {
|
||||
u16vec2 g = unpack16(iq1s_grid_const[i]);
|
||||
iq1s_grid[2*i+0] = g.x;
|
||||
iq1s_grid[2*i+1] = g.y;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
#endif
|
||||
|
||||
#define QUANT_K_IQ2_XXS 256
|
||||
#define QUANT_R_IQ2_XXS 1
|
||||
|
||||
@@ -561,7 +380,6 @@ const uvec2[256] iq2xxs_grid_const = {
|
||||
|
||||
shared uvec2 iq2xxs_grid[256];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
@@ -729,7 +547,6 @@ const uvec2 iq2xs_grid_const[512] = {
|
||||
|
||||
shared uvec2 iq2xs_grid[512];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
@@ -1019,7 +836,6 @@ const uvec2 iq2s_grid_const[1024] = {
|
||||
|
||||
shared uvec2 iq2s_grid[1024];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
@@ -1088,7 +904,6 @@ const uint32_t iq3xxs_grid_const[256] = {
|
||||
|
||||
shared uint32_t iq3xxs_grid[256];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
@@ -1196,7 +1011,6 @@ const uint32_t iq3s_grid_const[512] = {
|
||||
|
||||
shared uint32_t iq3s_grid[512];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
@@ -1259,7 +1073,6 @@ const int8_t kvalues_iq4nl_const[16] = {
|
||||
|
||||
shared FLOAT_TYPE kvalues_iq4nl[16];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
|
||||
@@ -55,8 +55,6 @@ const std::vector<std::string> type_names = {
|
||||
"q4_k",
|
||||
"q5_k",
|
||||
"q6_k",
|
||||
"iq1_s",
|
||||
"iq1_m",
|
||||
"iq2_xxs",
|
||||
"iq2_xs",
|
||||
"iq2_s",
|
||||
@@ -184,13 +182,6 @@ std::string to_uppercase(const std::string& input) {
|
||||
return result;
|
||||
}
|
||||
|
||||
bool string_starts_with(const std::string& str, const std::string& prefix) {
|
||||
if (prefix.size() > str.size()) {
|
||||
return false;
|
||||
}
|
||||
return std::equal(prefix.begin(), prefix.end(), str.begin());
|
||||
}
|
||||
|
||||
bool string_ends_with(const std::string& str, const std::string& suffix) {
|
||||
if (suffix.size() > str.size()) {
|
||||
return false;
|
||||
@@ -396,7 +387,7 @@ void process_shaders() {
|
||||
for (const auto& tname : type_names) {
|
||||
// mul mat vec
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
|
||||
std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
|
||||
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
|
||||
|
||||
@@ -1172,9 +1172,6 @@ extern "C" {
|
||||
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
|
||||
|
||||
/// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n);
|
||||
|
||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
These templates can be updated with the following commands:
|
||||
|
||||
```bash
|
||||
./scripts/get_chat_template.py CohereForAI/c4ai-command-r-plus tool_use > models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja
|
||||
./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 default > models/templates/CohereForAI-c4ai-command-r7b-12-2024-default.jinja
|
||||
./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 rag > models/templates/CohereForAI-c4ai-command-r7b-12-2024-rag.jinja
|
||||
./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use > models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja
|
||||
./scripts/get_chat_template.py deepseek-ai/DeepSeek-R1-Distill-Llama-8B > models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja
|
||||
./scripts/get_chat_template.py deepseek-ai/DeepSeek-R1-Distill-Qwen-32B > models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja
|
||||
./scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 > models/templates/fireworks-ai-llama-3-firefunction-v2.jinja
|
||||
./scripts/get_chat_template.py google/gemma-2-2b-it > models/templates/google-gemma-2-2b-it.jinja
|
||||
./scripts/get_chat_template.py meetkai/functionary-medium-v3. > models/templates/meetkai-functionary-medium-v3.jinja
|
||||
./scripts/get_chat_template.py meetkai/functionary-medium-v3.2 > models/templates/meetkai-functionary-medium-v3.2.jinja
|
||||
./scripts/get_chat_template.py meta-llama/Llama-3.1-8B-Instruct > models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja
|
||||
./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct > models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja
|
||||
./scripts/get_chat_template.py meta-llama/Llama-3.3-70B-Instruct > models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja
|
||||
./scripts/get_chat_template.py microsoft/Phi-3.5-mini-instruct > models/templates/microsoft-Phi-3.5-mini-instruct.jinja
|
||||
./scripts/get_chat_template.py mistralai/Mistral-Nemo-Instruct-2407 > models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja
|
||||
./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
|
||||
./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
|
||||
./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
|
||||
```
|
||||
@@ -1 +1 @@
|
||||
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\n'}}{% endif %}
|
||||
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}
|
||||
@@ -1 +1,56 @@
|
||||
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\n'}}{% endif %}
|
||||
{% if not add_generation_prompt is defined %}
|
||||
{% set add_generation_prompt = false %}
|
||||
{% endif %}
|
||||
{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'system' %}
|
||||
{% set ns.system_prompt = message['content'] %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{bos_token}}
|
||||
{{ns.system_prompt}}
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'user' %}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{{'<|User|>' + message['content']}}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'assistant' and message['content'] is none %}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- for tool in message['tool_calls']%}
|
||||
{%- if not ns.is_first %}
|
||||
{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{%- set ns.is_first = true -%}
|
||||
{%- else %}
|
||||
{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'assistant' and message['content'] is not none %}
|
||||
{%- if ns.is_tool %}
|
||||
{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- else %}
|
||||
{% set content = message['content'] %}
|
||||
{% if '</think>' in content %}
|
||||
{% set content = content.split('</think>')[-1] %}
|
||||
{% endif %}
|
||||
{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'tool' %}
|
||||
{%- set ns.is_tool = true -%}
|
||||
{%- if ns.is_output_first %}
|
||||
{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||
{%- set ns.is_output_first = false %}
|
||||
{%- else %}
|
||||
{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor -%}
|
||||
{% if ns.is_tool %}
|
||||
{{'<|tool▁outputs▁end|>'}}
|
||||
{% endif %}
|
||||
{% if add_generation_prompt and not ns.is_tool %}
|
||||
{{'<|Assistant|>'}}
|
||||
{% endif %}
|
||||
@@ -1,76 +0,0 @@
|
||||
{%- if not add_generation_prompt is defined -%}
|
||||
{%- set add_generation_prompt = false -%}
|
||||
{%- endif -%}
|
||||
{%- set ns = namespace(is_first=false, is_tool_outputs=false, is_output_first=true, system_prompt='') -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
{%- set ns.system_prompt = message['content'] -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{bos_token}}
|
||||
{%- if tools %}
|
||||
You can call any of the following function tools to satisfy the user's requests: {{tools | map(attribute='function') | tojson(indent=2)}}
|
||||
|
||||
Example function tool call syntax:
|
||||
|
||||
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>example_function_name
|
||||
```json
|
||||
{
|
||||
"arg1": "some_value"
|
||||
...
|
||||
}
|
||||
```
|
||||
<|tool▁call▁end|><|tool▁calls▁end|>
|
||||
|
||||
{% endif -%}
|
||||
{{ns.system_prompt}}
|
||||
{%- macro flush_tool_outputs() -%}
|
||||
{%- if ns.is_tool_outputs -%}
|
||||
{{- '<|tool▁outputs▁end|><|end▁of▁sentence|>' -}}
|
||||
{%- set ns.is_tool_outputs = false -%}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{{- flush_tool_outputs() -}}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] != 'tool' -%}
|
||||
{{- flush_tool_outputs() -}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{- '<|User|>' + message['content'] + '<|end▁of▁sentence|>' -}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and message['content'] is none -%}
|
||||
{{- '<|Assistant|><|tool▁calls▁begin|>' -}}
|
||||
{%- set ns.is_first = true -%}
|
||||
{%- for tc in message['tool_calls'] -%}
|
||||
{%- if ns.is_first -%}
|
||||
{%- set ns.is_first = false -%}
|
||||
{%- else -%}
|
||||
{{- '\n' -}}
|
||||
{%- endif -%}
|
||||
{%- set tool_name = tc['function']['name'] -%}
|
||||
{%- set tool_args = tc['function']['arguments'] -%}
|
||||
{{- '<|tool▁call▁begin|>' + tc['type'] + '<|tool▁sep|>' + tool_name + '\n' + '```json' + '\n' + tool_args + '\n' + '```' + '<|tool▁call▁end|>' -}}
|
||||
{%- endfor -%}
|
||||
{{- '<|tool▁calls▁end|><|end▁of▁sentence|>' -}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and message['content'] is not none -%}
|
||||
{{- flush_tool_outputs() -}}
|
||||
{%- set content = message['content'] -%}
|
||||
{%- if '</think>' in content -%}
|
||||
{%- set content = content.split('</think>')[-1] -%}
|
||||
{%- endif -%}
|
||||
{{- '<|Assistant|>' + content + '<|end▁of▁sentence|>' -}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'tool' -%}
|
||||
{%- set ns.is_tool_outputs = true -%}
|
||||
{%- if ns.is_output_first -%}
|
||||
{{- '<|tool▁outputs▁begin|>' -}}
|
||||
{%- set ns.is_output_first = false -%}
|
||||
{%- endif -%}
|
||||
{{- '\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- flush_tool_outputs() -}}
|
||||
{%- if add_generation_prompt and not ns.is_tool_outputs -%}
|
||||
{{- '<|Assistant|><think>\n' -}}
|
||||
{%- endif -%}
|
||||
5
scripts/get_chat_template.py
Executable file → Normal file
5
scripts/get_chat_template.py
Executable file → Normal file
@@ -7,8 +7,9 @@
|
||||
./scripts/get_chat_template.py model_id [variant]
|
||||
|
||||
Examples:
|
||||
./scripts/get_chat_template.py CohereForAI/c4ai-command-r-plus tool_use
|
||||
./scripts/get_chat_template.py microsoft/Phi-3.5-mini-instruct
|
||||
./scripts/get_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
|
||||
./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
|
||||
./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct
|
||||
'''
|
||||
|
||||
import json
|
||||
|
||||
@@ -1 +1 @@
|
||||
98a61a0d0b43cba06c3ac1c603813639552a0701
|
||||
08b538031f7f944e84f472483ef5d26bf5190ead
|
||||
|
||||
@@ -1186,7 +1186,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
return;
|
||||
}
|
||||
}
|
||||
LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
|
||||
LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ struct llama_kv_cache {
|
||||
bool can_shift = false;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
// for a free KV slot. llama_decode_internal also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
|
||||
@@ -1698,73 +1698,6 @@ struct llama_sampler * llama_sampler_init_penalties(
|
||||
);
|
||||
}
|
||||
|
||||
// top-n-sigma
|
||||
|
||||
struct llama_sampler_top_n_sigma {
|
||||
const float n;
|
||||
};
|
||||
|
||||
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
|
||||
return "top-n-sigma";
|
||||
}
|
||||
|
||||
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
|
||||
|
||||
// find max logit and calculate mean
|
||||
float max = cur_p->data[0].logit;
|
||||
float logits_sum = 0;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if (cur_p->data[i].logit > max) {
|
||||
max = cur_p->data[i].logit;
|
||||
}
|
||||
logits_sum += cur_p->data[i].logit;
|
||||
}
|
||||
float mean = logits_sum/cur_p->size;
|
||||
|
||||
// calculate standard deviation
|
||||
float acc = 0;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
acc += pow(cur_p->data[i].logit - mean, 2);
|
||||
}
|
||||
float std = sqrt(acc/cur_p->size);
|
||||
|
||||
//apply mask
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if (cur_p->data[i].logit < max - (ctx->n * std)) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
|
||||
return llama_sampler_init_top_n_sigma(ctx->n);
|
||||
}
|
||||
|
||||
static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_top_n_sigma *) smpl->ctx;
|
||||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
|
||||
/* .name = */ llama_sampler_top_n_sigma_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_top_n_sigma_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_top_n_sigma_clone,
|
||||
/* .free = */ llama_sampler_top_n_sigma_free,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
|
||||
return llama_sampler_init(
|
||||
/* .iface = */ &llama_sampler_top_n_sigma_i,
|
||||
/* .ctx = */ new llama_sampler_top_n_sigma {
|
||||
/* .n = */ n,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// DRY
|
||||
|
||||
struct llama_sampler_dry {
|
||||
|
||||
@@ -24,10 +24,7 @@ static common_chat_msg msg_from_json(const json & message) {
|
||||
ret.content = message.at("content");
|
||||
}
|
||||
if (message.contains("tool_plan")) {
|
||||
ret.reasoning_content = message.at("tool_plan");
|
||||
}
|
||||
if (message.contains("reasoning_content")) {
|
||||
ret.reasoning_content = message.at("reasoning_content");
|
||||
ret.tool_plan = message.at("tool_plan");
|
||||
}
|
||||
auto has_tool_calls = message.contains("tool_calls");
|
||||
if (has_tool_calls) {
|
||||
@@ -108,7 +105,6 @@ static std::string dump(const json & j) {
|
||||
static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
|
||||
assert_equals(expected.role, actual.role);
|
||||
assert_equals(expected.content, actual.content);
|
||||
assert_equals(expected.reasoning_content, actual.reasoning_content);
|
||||
assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
|
||||
for (size_t i = 0; i < expected.tool_calls.size(); i++) {
|
||||
const auto & expected_tool_call = expected.tool_calls[i];
|
||||
@@ -180,15 +176,13 @@ struct delta_data {
|
||||
|
||||
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
|
||||
const json & user_message, const json & delta_message, const json & tools,
|
||||
const json & tool_choice,
|
||||
bool think = false) {
|
||||
const json & tool_choice) {
|
||||
common_chat_inputs inputs;
|
||||
inputs.parallel_tool_calls = true;
|
||||
inputs.messages = json::array();
|
||||
inputs.messages.push_back(user_message);
|
||||
inputs.tools = tools;
|
||||
inputs.tool_choice = tool_choice;
|
||||
inputs.extract_reasoning = think;
|
||||
auto params_prefix = common_chat_params_init(tmpl, inputs);
|
||||
|
||||
inputs.messages.push_back(delta_message);
|
||||
@@ -198,24 +192,17 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
|
||||
std::string prefix = params_prefix.prompt;
|
||||
std::string full = params_full.prompt;
|
||||
|
||||
// Check full starts with prefix
|
||||
if (full.find(prefix) != 0) {
|
||||
fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str());
|
||||
throw std::runtime_error("Full message does not start with prefix");
|
||||
}
|
||||
|
||||
if (full == prefix) {
|
||||
throw std::runtime_error("Full message is the same as the prefix");
|
||||
}
|
||||
|
||||
size_t common_prefix_length = 0;
|
||||
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
|
||||
if (prefix[i] != full[i]) {
|
||||
break;
|
||||
}
|
||||
if (prefix[i] == '<') {
|
||||
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
|
||||
// but it removes thinking tags for past messages.
|
||||
// The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
|
||||
continue;
|
||||
}
|
||||
common_prefix_length = i + 1;
|
||||
}
|
||||
auto delta = full.substr(common_prefix_length);
|
||||
auto delta = full.substr(prefix.size());
|
||||
|
||||
// Strip end tokens
|
||||
for (const auto & end_token : end_tokens) {
|
||||
@@ -236,9 +223,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
|
||||
*/
|
||||
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
|
||||
const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
|
||||
bool expect_grammar_triggered = true,
|
||||
bool test_grammar_if_triggered = true,
|
||||
bool think = false) {
|
||||
bool expect_grammar_triggered = true) {
|
||||
common_chat_msg expected_msg = msg_from_json(test_message);
|
||||
|
||||
auto user_message = json{
|
||||
@@ -247,7 +232,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||
};
|
||||
|
||||
for (const auto & tool_choice : json({ "auto", "required" })) {
|
||||
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice, think);
|
||||
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
|
||||
if (!expected_delta.empty()) {
|
||||
assert_equals(expected_delta, data.delta);
|
||||
}
|
||||
@@ -289,7 +274,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||
assert_equals(expect_grammar_triggered, grammar_triggered);
|
||||
}
|
||||
|
||||
if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
|
||||
if (grammar_triggered && !match_string(constrained, grammar.get())) {
|
||||
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
|
||||
"\n\nGrammar: " + data.params.grammar);
|
||||
}
|
||||
@@ -298,33 +283,16 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||
}
|
||||
|
||||
static void test_template_output_parsers() {
|
||||
json message_user {
|
||||
{ "role", "user" },
|
||||
{ "content", "Hey there!" },
|
||||
};
|
||||
json message_assist {
|
||||
json text_message {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "Hello, world!\nWhat's up?" },
|
||||
};
|
||||
json message_assist_thoughts_unparsed_think {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "<think>I'm thinking</think>Hello, world!\nWhat's up?" },
|
||||
};
|
||||
json message_assist_thoughts_unparsed_r7b {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?" },
|
||||
};
|
||||
json message_assist_thoughts {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "Hello, world!\nWhat's up?" },
|
||||
{ "reasoning_content", "I'm thinking" },
|
||||
};
|
||||
json tool_calls = json::array({{
|
||||
{ "type", "function" },
|
||||
{ "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
|
||||
}});
|
||||
|
||||
json message_assist_call {
|
||||
json tool_call_message {
|
||||
{ "role", "assistant"},
|
||||
{ "content", {}},
|
||||
{ "tool_calls", {
|
||||
@@ -337,34 +305,7 @@ static void test_template_output_parsers() {
|
||||
},
|
||||
}},
|
||||
};
|
||||
json message_assist_call_thoughts = {
|
||||
{ "role", "assistant" },
|
||||
{ "content", nullptr },
|
||||
{ "reasoning_content", "I'm\nthinking" },
|
||||
{ "tool_calls", {
|
||||
{
|
||||
{ "type", "function" },
|
||||
{ "function", {
|
||||
{ "name", "special_function" },
|
||||
{ "arguments", "{\"arg1\": 1}" },
|
||||
}},
|
||||
},
|
||||
}},
|
||||
};
|
||||
json message_assist_call_thoughts_unparsed = {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "<think>I'm\nthinking</think>" },
|
||||
{ "tool_calls", {
|
||||
{
|
||||
{ "type", "function" },
|
||||
{ "function", {
|
||||
{ "name", "special_function" },
|
||||
{ "arguments", "{\"arg1\": 1}" },
|
||||
}},
|
||||
},
|
||||
}},
|
||||
};
|
||||
json message_assist_call_id {
|
||||
json tool_call_message_with_id {
|
||||
{ "role", "assistant"},
|
||||
{ "content", {}},
|
||||
{ "tool_calls", {
|
||||
@@ -381,9 +322,10 @@ static void test_template_output_parsers() {
|
||||
{ "content", {} },
|
||||
{ "tool_calls", tool_calls }
|
||||
};
|
||||
json message_assist_call_idx {
|
||||
json tool_call_plan_message_with_idx {
|
||||
{ "role", "assistant"},
|
||||
{ "content", {}},
|
||||
{ "tool_plan", "I'm not so sure"},
|
||||
{ "tool_calls", {
|
||||
{
|
||||
{ "type", "function" },
|
||||
@@ -399,10 +341,8 @@ static void test_template_output_parsers() {
|
||||
{ "content", {} },
|
||||
{ "tool_calls", tool_calls }
|
||||
};
|
||||
json message_assist_call_tool_plan_idx = message_assist_call_idx;
|
||||
message_assist_call_tool_plan_idx["tool_plan"] = "I'm thinking";
|
||||
|
||||
auto python_message_assist_call = json{
|
||||
auto python_tool_call_message = json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", {} },
|
||||
{ "tool_calls", json{ {
|
||||
@@ -417,7 +357,7 @@ static void test_template_output_parsers() {
|
||||
} },
|
||||
} } }
|
||||
};
|
||||
auto code_interpreter_message_assist_call = json{
|
||||
auto code_interpreter_tool_call_message = json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", {} },
|
||||
{ "tool_calls", json{ {
|
||||
@@ -434,27 +374,17 @@ static void test_template_output_parsers() {
|
||||
};
|
||||
|
||||
common_chat_inputs inputs_no_tools;
|
||||
inputs_no_tools.messages = json::array({message_user});
|
||||
inputs_no_tools.extract_reasoning = false;
|
||||
inputs_no_tools.messages = {
|
||||
{ { "role", "user" }, { "content", "Hey\nThere" } }
|
||||
};
|
||||
|
||||
common_chat_inputs inputs_no_tools_think;
|
||||
inputs_no_tools_think.messages = json::array({message_user});
|
||||
inputs_no_tools_think.extract_reasoning = true;
|
||||
common_chat_inputs inputs_tools = inputs_no_tools;
|
||||
inputs_tools.tools = json::array();
|
||||
inputs_tools.tools.push_back(special_function_tool);
|
||||
|
||||
common_chat_inputs inputs_tools;
|
||||
inputs_tools.messages = json::array({message_user});
|
||||
inputs_tools.tools = json::array({special_function_tool});
|
||||
inputs_tools.extract_reasoning = false;
|
||||
|
||||
common_chat_inputs inputs_tools_think;
|
||||
inputs_tools_think.messages = json::array({message_user});
|
||||
inputs_tools_think.tools = json::array({special_function_tool});
|
||||
inputs_tools_think.extract_reasoning = true;
|
||||
|
||||
common_chat_inputs inputs_tools_builtin;
|
||||
inputs_tools_builtin.messages = json::array({message_user});
|
||||
inputs_tools_builtin.tools = json::array({python_tool});
|
||||
inputs_tools_builtin.extract_reasoning = false;
|
||||
common_chat_inputs inputs_tools_builtin = inputs_no_tools;
|
||||
inputs_tools_builtin.tools = json::array();
|
||||
inputs_tools_builtin.tools.push_back(python_tool);
|
||||
|
||||
{
|
||||
// Not supported yet
|
||||
@@ -465,53 +395,15 @@ static void test_template_output_parsers() {
|
||||
const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
assert_msg_equals(msg_from_json(message_assist),
|
||||
common_chat_parse(
|
||||
"Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
assert_msg_equals(msg_from_json(message_assist),
|
||||
common_chat_parse(
|
||||
"Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
assert_msg_equals(msg_from_json(message_assist),
|
||||
common_chat_parse(
|
||||
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
|
||||
common_chat_parse(
|
||||
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
||||
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
|
||||
common_chat_parse(
|
||||
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
||||
"Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts),
|
||||
common_chat_parse(
|
||||
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
||||
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING));
|
||||
|
||||
test_template(tmpl, end_tokens, message_assist_call_idx, tools,
|
||||
"<|START_THINKING|><|END_THINKING|>"
|
||||
test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools,
|
||||
"<|START_THINKING|>I'm not so sure<|END_THINKING|>"
|
||||
"<|START_ACTION|>[\n"
|
||||
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
|
||||
"]<|END_ACTION|>");
|
||||
test_template(tmpl, end_tokens, message_assist_call_tool_plan_idx, tools,
|
||||
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
||||
"<|START_ACTION|>[\n"
|
||||
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
|
||||
"]<|END_ACTION|>",
|
||||
/* expect_grammar_triggered= */ true,
|
||||
/* test_grammar_if_triggered= */ true,
|
||||
/* think= */ true);
|
||||
test_template(tmpl, end_tokens, message_assist, tools,
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"<|START_RESPONSE|>Hello, world!\n"
|
||||
"What's up?<|END_RESPONSE|>",
|
||||
/* expect_grammar_triggered= */ false);
|
||||
@@ -531,12 +423,12 @@ static void test_template_output_parsers() {
|
||||
|
||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||
|
||||
assert_msg_equals(msg_from_json(message_assist),
|
||||
assert_msg_equals(msg_from_json(text_message),
|
||||
common_chat_parse("{\n"
|
||||
" \"response\": \"Hello, world!\\nWhat's up?\"\n"
|
||||
"}",
|
||||
common_chat_params_init(tmpl, inputs_tools).format));
|
||||
test_template(tmpl, end_tokens, message_assist_call_id, tools,
|
||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
"{\n"
|
||||
" \"tool_calls\": [\n"
|
||||
" {\n"
|
||||
@@ -556,9 +448,9 @@ static void test_template_output_parsers() {
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(
|
||||
tmpl, end_tokens, message_assist_call_id, tools,
|
||||
tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
|
||||
}
|
||||
{
|
||||
@@ -581,12 +473,12 @@ static void test_template_output_parsers() {
|
||||
inputs_tools)
|
||||
.format);
|
||||
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool_call>");
|
||||
test_template(tmpl, end_tokens, python_message_assist_call, tools,
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
|
||||
"</tool_call>");
|
||||
@@ -606,12 +498,12 @@ static void test_template_output_parsers() {
|
||||
inputs_tools_builtin)
|
||||
.format);
|
||||
|
||||
// test_template(tmpl, end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, code_interpreter_message_assist_call, llama_3_1_tools,
|
||||
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
|
||||
test_template(tmpl, end_tokens, python_message_assist_call, tools,
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||
"<|python_tag|>python.call(code=\"print('hey')\")");
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
@@ -621,8 +513,8 @@ static void test_template_output_parsers() {
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
@@ -633,8 +525,8 @@ static void test_template_output_parsers() {
|
||||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<function=special_function>{\"arg1\": 1}</function>");
|
||||
}
|
||||
{
|
||||
@@ -645,12 +537,12 @@ static void test_template_output_parsers() {
|
||||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, message_assist, {},
|
||||
test_template(tmpl, end_tokens, text_message, {},
|
||||
"all\n"
|
||||
"Hello, world!\n"
|
||||
"What's up?",
|
||||
/* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"special_function\n"
|
||||
"{\"arg1\": 1}");
|
||||
}
|
||||
@@ -661,79 +553,23 @@ static void test_template_output_parsers() {
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
||||
}
|
||||
{
|
||||
// Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt.
|
||||
const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
|
||||
"<s>", "</s>");
|
||||
std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
|
||||
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts),
|
||||
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts),
|
||||
// Latest template update (ast of 20250209) adds a trailing <think>\n if add_generation_prompt is true.
|
||||
common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
|
||||
// test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
// "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
// "```json\n"
|
||||
// "{\"arg1\": 1}\n"
|
||||
// // Look what's not here: <|tool▁calls▁end|> (also missing the <|end▁of▁sentence|>, but that is removed lazily by the test's delta logic)
|
||||
// "```<|tool▁call▁end|>",
|
||||
// /* expect_grammar_triggered= */ true,
|
||||
// /* test_grammar_if_triggered= */ false);
|
||||
}
|
||||
{
|
||||
// Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all.
|
||||
const common_chat_template tmpl(read_file("models/templates/llama-cpp-deepseek-r1.jinja"),
|
||||
"<s>", "</s>");
|
||||
std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
|
||||
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
|
||||
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts),
|
||||
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
|
||||
|
||||
assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed),
|
||||
common_chat_parse(
|
||||
"<think>I'm\nthinking</think>\n\n"
|
||||
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
"```json\n"
|
||||
"{\"arg1\": 1}\n"
|
||||
"```<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
assert_msg_equals(msg_from_json(message_assist_call_thoughts),
|
||||
common_chat_parse(
|
||||
"<think>I'm\nthinking</think>\n\n"
|
||||
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
"```json\n"
|
||||
"{\"arg1\": 1}\n"
|
||||
"```<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
"```json\n"
|
||||
"{\"arg1\": 1}\n"
|
||||
"```<|tool▁call▁end|><|tool▁calls▁end|>");
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
"```json\n"
|
||||
"{\"arg1\": 1}\n"
|
||||
"```<|tool▁call▁end|>");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -750,20 +586,16 @@ int main(int argc, char ** argv) {
|
||||
std::cout << "|----------|--------|\n";
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
try {
|
||||
std::string path = argv[i];
|
||||
if (path.rfind(".jinja") != path.size() - 6) {
|
||||
std::cerr << "Skipping non-jinja file: " << path << std::endl;
|
||||
continue;
|
||||
}
|
||||
common_chat_template tmpl(read_file(path), "", "");
|
||||
auto parts = string_split(path, "/");
|
||||
auto name = parts[parts.size() - 1];
|
||||
auto format = common_chat_format_name(common_chat_params_init(tmpl, inputs).format);
|
||||
std::cout << "| " << name << " | " << format << " |\n";
|
||||
} catch (const std::exception & e) {
|
||||
std::cerr << "Failed to process " << argv[i] << ": " << e.what() << std::endl;
|
||||
std::string path = argv[i];
|
||||
if (path.rfind(".jinja") != path.size() - 6) {
|
||||
std::cerr << "Skipping non-jinja file: " << path << std::endl;
|
||||
continue;
|
||||
}
|
||||
common_chat_template tmpl(read_file(path), "", "");
|
||||
auto parts = string_split(path, "/");
|
||||
auto name = parts[parts.size() - 1];
|
||||
std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format)
|
||||
<< " |\n";
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
|
||||
@@ -181,17 +181,6 @@ static void test_dry(
|
||||
tester.check();
|
||||
}
|
||||
|
||||
static void test_top_n_sigma(const std::vector<float> & probs, const std::vector<float> & probs_expected, int n) {
|
||||
sampler_tester tester(probs, probs_expected);
|
||||
|
||||
DUMP(&tester.cur_p);
|
||||
tester.apply(llama_sampler_init_top_n_sigma(n));
|
||||
tester.apply(llama_sampler_init_dist (0));
|
||||
DUMP(&tester.cur_p);
|
||||
|
||||
tester.check();
|
||||
}
|
||||
|
||||
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
|
||||
) {
|
||||
sampler_tester tester(n_vocab);
|
||||
@@ -359,10 +348,6 @@ int main(void) {
|
||||
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
|
||||
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
|
||||
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f);
|
||||
|
||||
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
|
||||
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
|
||||
test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
|
||||
|
||||
Reference in New Issue
Block a user