mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-05 13:53:23 +02:00
Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
08d5986290 | ||
|
|
651adf4b66 | ||
|
|
8303e8b0fb | ||
|
|
7ad0779f5d | ||
|
|
f777a73e18 | ||
|
|
af7747c95a | ||
|
|
a28e0d5eb1 | ||
|
|
36c258ee92 | ||
|
|
f3e64859ed | ||
|
|
5fa07c2f93 | ||
|
|
335eb04a91 | ||
|
|
cf756d6e0a | ||
|
|
d70908421f | ||
|
|
de8b5a3624 | ||
|
|
51f311e057 | ||
|
|
586d5fe6eb | ||
|
|
ecc8e3aeff | ||
|
|
0b3863ff95 | ||
|
|
ee02ad02c5 | ||
|
|
c392e5094d | ||
|
|
c5d91a7400 | ||
|
|
4806498bf1 | ||
|
|
0d559580a0 | ||
|
|
d04e7163c8 | ||
|
|
d07c621393 | ||
|
|
abd4d0bc4f | ||
|
|
9626d9351a | ||
|
|
b58934c183 | ||
|
|
63e489c025 | ||
|
|
63ac128563 | ||
|
|
5137da7b8c | ||
|
|
09aaf4f1f5 | ||
|
|
73e2ed3ce3 | ||
|
|
f7b1116af1 | ||
|
|
c4d29baf32 | ||
|
|
2eea03d86a | ||
|
|
0f2bbe6564 | ||
|
|
fe163d5bf3 | ||
|
|
818a340ea8 | ||
|
|
bf42a23d0a | ||
|
|
c2ea16f260 | ||
|
|
6dde178248 | ||
|
|
fc10c38ded |
20
.github/workflows/build.yml
vendored
20
.github/workflows/build.yml
vendored
@@ -173,7 +173,15 @@ jobs:
|
||||
name: llama-bin-macos-x64.zip
|
||||
|
||||
ubuntu-cpu-cmake:
|
||||
runs-on: ubuntu-22.04
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- build: 'x64'
|
||||
os: ubuntu-22.04
|
||||
- build: 'arm64'
|
||||
os: ubuntu-22.04-arm
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -239,14 +247,14 @@ jobs:
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp
|
||||
zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.zip ./build/bin/*
|
||||
zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip ./build/bin/*
|
||||
|
||||
- name: Upload artifacts
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.zip
|
||||
name: llama-bin-ubuntu-x64.zip
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip
|
||||
name: llama-bin-ubuntu-${{ matrix.build }}.zip
|
||||
|
||||
ubuntu-latest-cmake-sanitizer:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -374,6 +382,8 @@ jobs:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
@@ -1373,8 +1383,10 @@ jobs:
|
||||
|
||||
needs:
|
||||
- ubuntu-cpu-cmake
|
||||
- ubuntu-22-cmake-vulkan
|
||||
- windows-latest-cmake
|
||||
- windows-2019-cmake-cuda
|
||||
- windows-latest-cmake-sycl
|
||||
- windows-latest-cmake-hip-release
|
||||
- macOS-latest-cmake-arm64
|
||||
- macOS-latest-cmake-x64
|
||||
|
||||
2
.github/workflows/docker.yml
vendored
2
.github/workflows/docker.yml
vendored
@@ -51,6 +51,8 @@ jobs:
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
with:
|
||||
image: tonistiigi/binfmt:qemu-v7.0.0-28
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -98,6 +98,7 @@ examples/server/*.css.hpp
|
||||
examples/server/*.html.hpp
|
||||
examples/server/*.js.hpp
|
||||
examples/server/*.mjs.hpp
|
||||
examples/server/*.gz.hpp
|
||||
!build_64.sh
|
||||
!examples/*.bat
|
||||
!examples/*/*.kts
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# Pull requests (for contributors)
|
||||
|
||||
- llama.cpp uses the ggml tensor library for model evaluation. If you are unfamiliar with ggml, consider taking a look at the [examples in the ggml repository](https://github.com/ggml-org/ggml/tree/master/examples/). [simple](https://github.com/ggml-org/ggml/tree/master/examples/simple) shows the bare minimum for using ggml. [gpt-2](https://github.com/ggml-org/ggml/tree/master/examples/gpt-2) has minimal implementations for language model inference using GPT-2. [mnist](https://github.com/ggml-org/ggml/tree/master/examples/mnist) demonstrates how to train and evaluate a simple image classifier
|
||||
- Test your changes:
|
||||
- Execute [the full CI locally on your machine](ci/README.md) before publishing
|
||||
- Verify that the perplexity and the performance are not affected negatively by your changes (use `llama-perplexity` and `llama-bench`)
|
||||
- If you modified the `ggml` source, run the `test-backend-ops` tool to check whether different backend implementations of the `ggml` operators produce consistent results (this requires access to at least two different `ggml` backends)
|
||||
- If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops`
|
||||
- Create separate PRs for each feature or fix. Avoid combining unrelated changes in a single PR
|
||||
- Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly
|
||||
- If your PR becomes stale, don't hesitate to ping the maintainers in the comments
|
||||
|
||||
|
||||
16
Makefile
16
Makefile
@@ -680,6 +680,10 @@ ifdef GGML_CUDA_CCBIN
|
||||
MK_NVCCFLAGS += -ccbin $(GGML_CUDA_CCBIN)
|
||||
endif # GGML_CUDA_CCBIN
|
||||
|
||||
ifdef GGML_CUDA_NO_FA
|
||||
MK_NVCCFLAGS += -DGGML_CUDA_NO_FA
|
||||
endif # GGML_CUDA_NO_FA
|
||||
|
||||
ifdef GGML_CUDA_FA_ALL_QUANTS
|
||||
MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
|
||||
endif # GGML_CUDA_FA_ALL_QUANTS
|
||||
@@ -800,6 +804,10 @@ ifdef GGML_CUDA_NO_PEER_COPY
|
||||
HIPFLAGS += -DGGML_CUDA_NO_PEER_COPY
|
||||
endif # GGML_CUDA_NO_PEER_COPY
|
||||
|
||||
ifdef GGML_CUDA_NO_FA
|
||||
HIPFLAGS += -DGGML_CUDA_NO_FA
|
||||
endif # GGML_CUDA_NO_FA
|
||||
|
||||
OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o
|
||||
OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu))
|
||||
OBJ_GGML_EXT += $(OBJ_CUDA_TMPL)
|
||||
@@ -847,7 +855,7 @@ ifdef GGML_MUSA
|
||||
CXX := $(MUSA_PATH)/bin/clang++
|
||||
MCC := $(CCACHE) $(MUSA_PATH)/bin/mcc
|
||||
|
||||
MUSAFLAGS = -x musa -mtgpu
|
||||
MUSAFLAGS = -fsigned-char -x musa -mtgpu
|
||||
MUSAFLAGS += $(foreach arch,$(subst ;, ,$(MUSA_ARCHITECTURES)),--cuda-gpu-arch=mp_$(arch))
|
||||
|
||||
ifdef GGML_CUDA_FORCE_MMQ
|
||||
@@ -876,6 +884,10 @@ ifdef GGML_CUDA_NO_PEER_COPY
|
||||
MUSAFLAGS += -DGGML_CUDA_NO_PEER_COPY
|
||||
endif # GGML_CUDA_NO_PEER_COPY
|
||||
|
||||
ifdef GGML_CUDA_NO_FA
|
||||
MUSAFLAGS += -DGGML_CUDA_NO_FA
|
||||
endif # GGML_CUDA_NO_FA
|
||||
|
||||
ifdef GGML_CUDA_FA_ALL_QUANTS
|
||||
MUSAFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
|
||||
endif # GGML_CUDA_FA_ALL_QUANTS
|
||||
@@ -1364,7 +1376,7 @@ llama-server: \
|
||||
examples/server/index.html.hpp \
|
||||
examples/server/loading.html.hpp \
|
||||
common/chat.cpp \
|
||||
common/chat.hpp \
|
||||
common/chat.h \
|
||||
common/chat-template.hpp \
|
||||
common/json.hpp \
|
||||
common/minja.hpp \
|
||||
|
||||
@@ -57,8 +57,7 @@ add_library(${TARGET} STATIC
|
||||
arg.h
|
||||
base64.hpp
|
||||
chat.cpp
|
||||
chat.hpp
|
||||
chat-template.hpp
|
||||
chat.h
|
||||
common.cpp
|
||||
common.h
|
||||
console.cpp
|
||||
@@ -68,7 +67,8 @@ add_library(${TARGET} STATIC
|
||||
llguidance.cpp
|
||||
log.cpp
|
||||
log.h
|
||||
minja.hpp
|
||||
minja/chat-template.hpp
|
||||
minja/minja.hpp
|
||||
ngram-cache.cpp
|
||||
ngram-cache.h
|
||||
sampling.cpp
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "chat.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
@@ -2247,7 +2248,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_env("LLAMA_LOG_VERBOSITY"));
|
||||
add_opt(common_arg(
|
||||
{"--log-prefix"},
|
||||
"Enable prefx in log messages",
|
||||
"Enable prefix in log messages",
|
||||
[](common_params &) {
|
||||
common_log_set_prefix(common_log_main(), true);
|
||||
}
|
||||
@@ -2501,5 +2502,53 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--fim-qwen-1.5b-default"},
|
||||
string_format("use default Qwen 2.5 Coder 1.5B (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF";
|
||||
params.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_gpu_layers = 99;
|
||||
params.flash_attn = true;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
params.n_ctx = 0;
|
||||
params.n_cache_reuse = 256;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--fim-qwen-3b-default"},
|
||||
string_format("use default Qwen 2.5 Coder 3B (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF";
|
||||
params.hf_file = "qwen2.5-coder-3b-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_gpu_layers = 99;
|
||||
params.flash_attn = true;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
params.n_ctx = 0;
|
||||
params.n_cache_reuse = 256;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--fim-qwen-7b-default"},
|
||||
string_format("use default Qwen 2.5 Coder 7B (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
|
||||
params.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_gpu_layers = 99;
|
||||
params.flash_attn = true;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
params.n_ctx = 0;
|
||||
params.n_cache_reuse = 256;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
return ctx_arg;
|
||||
}
|
||||
|
||||
730
common/chat.cpp
730
common/chat.cpp
@@ -1,8 +1,433 @@
|
||||
#include "chat.hpp"
|
||||
#include "chat-template.hpp"
|
||||
#include "chat.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "minja.hpp"
|
||||
#include "minja/chat-template.hpp"
|
||||
#include "minja/minja.hpp"
|
||||
|
||||
#include <optional>
|
||||
|
||||
typedef minja::chat_template common_chat_template;
|
||||
|
||||
struct common_chat_templates {
|
||||
bool has_explicit_template; // Model had builtin template or template overridde was specified.
|
||||
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
|
||||
std::unique_ptr<common_chat_template> template_tool_use;
|
||||
};
|
||||
|
||||
struct templates_params {
|
||||
json messages;
|
||||
json tools;
|
||||
common_chat_tool_choice tool_choice;
|
||||
json json_schema;
|
||||
bool parallel_tool_calls;
|
||||
bool stream;
|
||||
std::string grammar;
|
||||
bool add_generation_prompt = true;
|
||||
bool extract_reasoning = true;
|
||||
};
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
|
||||
if (tool_choice == "auto") {
|
||||
return COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
}
|
||||
if (tool_choice == "none") {
|
||||
return COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
}
|
||||
if (tool_choice == "required") {
|
||||
return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
}
|
||||
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
|
||||
std::vector<common_chat_msg> msgs;
|
||||
|
||||
try {
|
||||
|
||||
if (!messages.is_array()) {
|
||||
throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
|
||||
}
|
||||
|
||||
for (const auto & message : messages) {
|
||||
if (!message.is_object()) {
|
||||
throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
|
||||
}
|
||||
|
||||
common_chat_msg msg;
|
||||
if (!message.contains("role")) {
|
||||
throw std::runtime_error("Missing 'role' in message: " + message.dump());
|
||||
}
|
||||
msg.role = message.at("role");
|
||||
|
||||
if (message.contains("content")) {
|
||||
const auto & content = message.at("content");
|
||||
if (content.is_string()) {
|
||||
msg.content = content;
|
||||
} else if (content.is_array()) {
|
||||
for (const auto & part : content) {
|
||||
if (!part.contains("type")) {
|
||||
throw std::runtime_error("Missing content part type: " + part.dump());
|
||||
}
|
||||
const auto & type = part.at("type");
|
||||
if (type != "text") {
|
||||
throw std::runtime_error("Unsupported content part type: " + type.dump());
|
||||
}
|
||||
common_chat_msg_content_part msg_part;
|
||||
msg_part.type = type;
|
||||
msg_part.text = part.at("text");
|
||||
msg.content_parts.push_back(msg_part);
|
||||
}
|
||||
} else if (!content.is_null()) {
|
||||
throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
|
||||
}
|
||||
if (message.contains("reasoning_content")) {
|
||||
msg.reasoning_content = message.at("reasoning_content");
|
||||
}
|
||||
if (message.contains("name")) {
|
||||
msg.tool_name = message.at("name");
|
||||
}
|
||||
if (message.contains("tool_call_id")) {
|
||||
msg.tool_call_id = message.at("tool_call_id");
|
||||
}
|
||||
if (message.contains("tool_calls")) {
|
||||
for (const auto & tool_call : message.at("tool_calls")) {
|
||||
common_chat_tool_call tc;
|
||||
if (!tool_call.contains("type")) {
|
||||
throw std::runtime_error("Missing tool call type: " + tool_call.dump());
|
||||
}
|
||||
const auto & type = tool_call.at("type");
|
||||
if (type != "function") {
|
||||
throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
|
||||
}
|
||||
if (!tool_call.contains("function")) {
|
||||
throw std::runtime_error("Missing tool call function: " + tool_call.dump());
|
||||
}
|
||||
const auto & fc = tool_call.at("function");
|
||||
if (!fc.contains("name")) {
|
||||
throw std::runtime_error("Missing tool call name: " + tool_call.dump());
|
||||
}
|
||||
tc.name = fc.at("name");
|
||||
tc.arguments = fc.at("arguments");
|
||||
if (tool_call.contains("id")) {
|
||||
tc.id = tool_call.at("id");
|
||||
}
|
||||
msg.tool_calls.push_back(tc);
|
||||
}
|
||||
}
|
||||
|
||||
msgs.push_back(msg);
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
|
||||
}
|
||||
|
||||
return msgs;
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
|
||||
json messages = json::array();
|
||||
for (const auto & msg : msgs) {
|
||||
if (!msg.content.empty() && !msg.content_parts.empty()) {
|
||||
throw std::runtime_error("Cannot specify both content and content_parts");
|
||||
}
|
||||
json jmsg {
|
||||
{"role", msg.role},
|
||||
};
|
||||
if (!msg.content.empty()) {
|
||||
jmsg["content"] = msg.content;
|
||||
} else if (!msg.content_parts.empty()) {
|
||||
if (concat_typed_text) {
|
||||
std::string text;
|
||||
for (const auto & part : msg.content_parts) {
|
||||
if (part.type != "text") {
|
||||
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
|
||||
continue;
|
||||
}
|
||||
if (!text.empty()) {
|
||||
text += '\n';
|
||||
}
|
||||
text += part.text;
|
||||
}
|
||||
jmsg["content"] = text;
|
||||
} else {
|
||||
auto & parts = jmsg["content"] = json::array();
|
||||
for (const auto & part : msg.content_parts) {
|
||||
parts.push_back({
|
||||
{"type", part.type},
|
||||
{"text", part.text},
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
jmsg["content"] = json(); // null
|
||||
}
|
||||
if (!msg.reasoning_content.empty()) {
|
||||
jmsg["reasoning_content"] = msg.reasoning_content;
|
||||
}
|
||||
if (!msg.tool_name.empty()) {
|
||||
jmsg["name"] = msg.tool_name;
|
||||
}
|
||||
if (!msg.tool_call_id.empty()) {
|
||||
jmsg["tool_call_id"] = msg.tool_call_id;
|
||||
}
|
||||
if (!msg.tool_calls.empty()) {
|
||||
auto & tool_calls = jmsg["tool_calls"] = json::array();
|
||||
for (const auto & tool_call : msg.tool_calls) {
|
||||
json tc {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tool_call.name},
|
||||
{"arguments", tool_call.arguments},
|
||||
}},
|
||||
};
|
||||
if (!tool_call.id.empty()) {
|
||||
tc["id"] = tool_call.id;
|
||||
}
|
||||
tool_calls.push_back(tc);
|
||||
}
|
||||
}
|
||||
messages.push_back(jmsg);
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
|
||||
return common_chat_msgs_parse_oaicompat(json::parse(messages));
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
|
||||
std::vector<common_chat_tool> result;
|
||||
|
||||
try {
|
||||
if (!tools.is_null()) {
|
||||
if (!tools.is_array()) {
|
||||
throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
|
||||
}
|
||||
for (const auto & tool : tools) {
|
||||
if (!tool.contains("type")) {
|
||||
throw std::runtime_error("Missing tool type: " + tool.dump());
|
||||
}
|
||||
const auto & type = tool.at("type");
|
||||
if (!type.is_string() || type != "function") {
|
||||
throw std::runtime_error("Unsupported tool type: " + tool.dump());
|
||||
}
|
||||
if (!tool.contains("function")) {
|
||||
throw std::runtime_error("Missing tool function: " + tool.dump());
|
||||
}
|
||||
|
||||
const auto & function = tool.at("function");
|
||||
result.push_back({
|
||||
/* .name = */ function.at("name"),
|
||||
/* .description = */ function.at("description"),
|
||||
/* .parameters = */ function.at("parameters").dump(),
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
|
||||
return common_chat_tools_parse_oaicompat(json::parse(tools));
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
|
||||
if (tools.empty()) {
|
||||
return json();
|
||||
}
|
||||
|
||||
auto result = json::array();
|
||||
for (const auto & tool : tools) {
|
||||
result.push_back({
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tool.name},
|
||||
{"description", tool.description},
|
||||
{"parameters", json::parse(tool.parameters)},
|
||||
}},
|
||||
});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
try {
|
||||
common_chat_msg msg;
|
||||
msg.role = "user";
|
||||
msg.content = "test";
|
||||
|
||||
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
|
||||
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = {msg};
|
||||
|
||||
common_chat_templates_apply(tmpls.get(), inputs);
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||
return res >= 0;
|
||||
}
|
||||
|
||||
std::string common_chat_format_single(
|
||||
const struct common_chat_templates * tmpls,
|
||||
const std::vector<common_chat_msg> & past_msg,
|
||||
const common_chat_msg & new_msg,
|
||||
bool add_ass,
|
||||
bool use_jinja) {
|
||||
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.use_jinja = use_jinja;
|
||||
|
||||
std::string fmt_past_msg;
|
||||
if (!past_msg.empty()) {
|
||||
inputs.messages = past_msg;
|
||||
inputs.add_generation_prompt = false;
|
||||
fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
|
||||
}
|
||||
std::ostringstream ss;
|
||||
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
||||
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
||||
ss << "\n";
|
||||
};
|
||||
// format chat with new_msg
|
||||
inputs.messages.push_back(new_msg);
|
||||
inputs.add_generation_prompt = add_ass;
|
||||
auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
|
||||
// get the diff part
|
||||
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.use_jinja = use_jinja;
|
||||
auto add_simple_msg = [&](auto role, auto content) {
|
||||
common_chat_msg msg;
|
||||
msg.role = role;
|
||||
msg.content = content;
|
||||
inputs.messages.push_back(msg);
|
||||
};
|
||||
add_simple_msg("system", "You are a helpful assistant");
|
||||
add_simple_msg("user", "Hello");
|
||||
add_simple_msg("assistant", "Hi there");
|
||||
add_simple_msg("user", "How are you?");
|
||||
return common_chat_templates_apply(tmpls, inputs).prompt;
|
||||
}
|
||||
|
||||
#define CHATML_TEMPLATE_SRC \
|
||||
"{%- for message in messages -%}\n" \
|
||||
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
|
||||
"{%- endfor -%}\n" \
|
||||
"{%- if add_generation_prompt -%}\n" \
|
||||
" {{- '<|im_start|>assistant\n' -}}\n" \
|
||||
"{%- endif -%}"
|
||||
|
||||
void common_chat_templates_free(struct common_chat_templates * tmpls) {
|
||||
delete tmpls;
|
||||
}
|
||||
|
||||
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
|
||||
return tmpls->has_explicit_template;
|
||||
}
|
||||
|
||||
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
|
||||
if (variant != nullptr) {
|
||||
if (strcmp(variant, "tool_use") == 0) {
|
||||
if (tmpls->template_tool_use) {
|
||||
return tmpls->template_tool_use->source().c_str();
|
||||
}
|
||||
return nullptr;
|
||||
} else {
|
||||
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
|
||||
}
|
||||
}
|
||||
return tmpls->template_default->source().c_str();
|
||||
}
|
||||
|
||||
common_chat_templates_ptr common_chat_templates_init(
|
||||
const struct llama_model * model,
|
||||
const std::string & chat_template_override,
|
||||
const std::string & bos_token_override,
|
||||
const std::string & eos_token_override)
|
||||
{
|
||||
std::string default_template_src;
|
||||
std::string template_tool_use_src;
|
||||
|
||||
bool has_explicit_template = !chat_template_override.empty();
|
||||
if (chat_template_override.empty()) {
|
||||
GGML_ASSERT(model != nullptr);
|
||||
const auto * str = llama_model_chat_template(model, /* name */ nullptr);
|
||||
if (str) {
|
||||
default_template_src = str;
|
||||
has_explicit_template = true;
|
||||
}
|
||||
str = llama_model_chat_template(model, /* name */ "tool_use");
|
||||
if (str) {
|
||||
template_tool_use_src = str;
|
||||
has_explicit_template = true;
|
||||
}
|
||||
} else {
|
||||
default_template_src = chat_template_override;
|
||||
}
|
||||
if (default_template_src.empty() || default_template_src == "chatml") {
|
||||
if (!template_tool_use_src.empty()) {
|
||||
default_template_src = template_tool_use_src;
|
||||
} else {
|
||||
default_template_src = CHATML_TEMPLATE_SRC;
|
||||
}
|
||||
}
|
||||
std::string token_bos = bos_token_override;
|
||||
std::string token_eos = eos_token_override;
|
||||
if (model) {
|
||||
const auto * vocab = llama_model_get_vocab(model);
|
||||
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
|
||||
if (token == LLAMA_TOKEN_NULL) {
|
||||
if (default_template_src.find(jinja_variable_name) != std::string::npos
|
||||
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
|
||||
LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
|
||||
}
|
||||
return std::string();
|
||||
}
|
||||
return common_token_to_piece(vocab, token, true);
|
||||
};
|
||||
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
|
||||
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
|
||||
}
|
||||
common_chat_templates_ptr tmpls(new common_chat_templates());
|
||||
tmpls->has_explicit_template = has_explicit_template;
|
||||
try {
|
||||
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
|
||||
tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
|
||||
}
|
||||
if (!template_tool_use_src.empty()) {
|
||||
try {
|
||||
tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
|
||||
}
|
||||
}
|
||||
return tmpls;
|
||||
}
|
||||
|
||||
std::string common_chat_format_name(common_chat_format format) {
|
||||
switch (format) {
|
||||
@@ -38,22 +463,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
|
||||
|
||||
json_error_locator() : position(0), found_error(false) {}
|
||||
|
||||
bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
|
||||
bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
|
||||
this->position = position - 1;
|
||||
this->found_error = true;
|
||||
return false;
|
||||
}
|
||||
bool null() override { return true; }
|
||||
bool boolean(bool) override { return true; }
|
||||
bool number_integer(number_integer_t) override { return true; }
|
||||
bool number_unsigned(number_unsigned_t) override { return true; }
|
||||
bool number_float(number_float_t, const string_t &) override { return true; }
|
||||
bool string(string_t &) override { return true; }
|
||||
bool binary(binary_t &) override { return true; }
|
||||
bool start_object(std::size_t) override { return true; }
|
||||
bool key(string_t &) override { return true; }
|
||||
bool null() override { return true; } // NOLINT
|
||||
bool boolean(bool) override { return true; } // NOLINT
|
||||
bool number_integer(number_integer_t) override { return true; } // NOLINT
|
||||
bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
|
||||
bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
|
||||
bool string(string_t &) override { return true; } // NOLINT
|
||||
bool binary(binary_t &) override { return true; } // NOLINT
|
||||
bool start_object(std::size_t) override { return true; } // NOLINT
|
||||
bool key(string_t &) override { return true; } // NOLINT
|
||||
bool end_object() override { return true; }
|
||||
bool start_array(std::size_t) override { return true; }
|
||||
bool start_array(std::size_t) override { return true; } // NOLINT
|
||||
bool end_array() override { return true; }
|
||||
};
|
||||
json_error_locator err_loc;
|
||||
@@ -187,13 +612,20 @@ static std::string apply(
|
||||
// tmpl_inputs.now = std::chrono::system_clock::now();
|
||||
|
||||
minja::chat_template_options tmpl_opts;
|
||||
tmpl_opts.use_bos_token = false;
|
||||
tmpl_opts.use_eos_token = false;
|
||||
|
||||
return tmpl.apply(tmpl_inputs, tmpl_opts);
|
||||
// To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
|
||||
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
|
||||
// may be needed inside the template / between messages too.
|
||||
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
|
||||
if (string_starts_with(result, tmpl.bos_token())) {
|
||||
result = result.substr(tmpl.bos_token().size());
|
||||
}
|
||||
if (string_ends_with(result, tmpl.eos_token())) {
|
||||
result = result.substr(0, result.size() - tmpl.eos_token().size());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
auto tool_call_schemas = json::array();
|
||||
@@ -247,7 +679,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
|
||||
{"required", json::array({"tool_call"})},
|
||||
};
|
||||
const auto schema =
|
||||
inputs.tool_choice != "required"
|
||||
inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
|
||||
? json {
|
||||
{"anyOf", json::array({
|
||||
tool_call,
|
||||
@@ -303,9 +735,9 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
|
||||
return result;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
@@ -348,9 +780,9 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
|
||||
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
@@ -455,10 +887,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
|
||||
const auto & parameters_required = parameters.at("required");
|
||||
for (const auto & prop : expected_properties) {
|
||||
if (!parameters_properties.contains(prop)) {
|
||||
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
|
||||
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
|
||||
}
|
||||
if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
|
||||
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
|
||||
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
|
||||
}
|
||||
}
|
||||
if (parameters_properties.size() != expected_properties.size()) {
|
||||
@@ -466,18 +898,16 @@ static void expect_tool_parameters(const std::string & name, const json & parame
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
|
||||
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
|
||||
auto builtin_tools = json::array();
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
||||
if (name == "wolfram_alpha") {
|
||||
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
||||
expect_tool_parameters(name, parameters, {"query"});
|
||||
} else if (name == "web_search" || name == "brave_search") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
||||
expect_tool_parameters(name, parameters, {"query"});
|
||||
} else if (name == "python" || name == "code_interpreter") {
|
||||
@@ -489,7 +919,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
||||
|
||||
std::vector<std::string> kvs;
|
||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
|
||||
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
|
||||
}
|
||||
|
||||
tool_rules.push_back(
|
||||
@@ -560,34 +990,33 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
|
||||
auto arg_value_str = raw_args.substr(it_eq + 1);
|
||||
auto arg_value = json::parse(arg_value_str);
|
||||
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ match.prefix().str(),
|
||||
/* .tool_calls = */ {
|
||||
{
|
||||
/* .name = */ match[1],
|
||||
/* .arguments = */ (json {
|
||||
{arg_name, arg_value},
|
||||
}).dump(),
|
||||
/* .id = */ "",
|
||||
},
|
||||
},
|
||||
};
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
msg.content = match.prefix().str();
|
||||
msg.tool_calls.push_back({
|
||||
/* .name = */ name,
|
||||
/* .arguments = */ (json {
|
||||
{arg_name, arg_value},
|
||||
}).dump(),
|
||||
/* .id = */ "",
|
||||
});
|
||||
return msg;
|
||||
}
|
||||
}
|
||||
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null();
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
|
||||
@@ -666,15 +1095,15 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input,
|
||||
return msg;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
fprintf(stderr, "%s\n", __func__);
|
||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
});
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
@@ -712,14 +1141,14 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp
|
||||
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> first_tool_rules;
|
||||
std::vector<std::string> subsequent_tool_rules;
|
||||
@@ -727,6 +1156,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
||||
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
||||
@@ -795,14 +1225,14 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
common_chat_params data;
|
||||
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
|
||||
std::string python_code_argument_name;
|
||||
auto has_raw_python = false;
|
||||
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
@@ -814,7 +1244,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
||||
throw std::runtime_error("Missing type in python tool");
|
||||
}
|
||||
has_raw_python = true;
|
||||
auto type = parameters.at("type");
|
||||
const auto & type = parameters.at("type");
|
||||
if (type == "object") {
|
||||
auto properties = parameters.at("properties");
|
||||
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
||||
@@ -854,17 +1284,15 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
|
||||
std::smatch match;
|
||||
if (std::regex_search(input, match, python_tag_regex)) {
|
||||
auto code = match[1].str();
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ match.prefix().str(),
|
||||
/* .tool_calls = */ {
|
||||
{
|
||||
/* .name = */ "python",
|
||||
/* .arguments = */ (json {{"code", code}}).dump(),
|
||||
/* .id = */ "",
|
||||
},
|
||||
}
|
||||
};
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
msg.content = match.prefix().str();
|
||||
msg.tool_calls.push_back({
|
||||
/* .name = */ "python",
|
||||
/* .arguments = */ (json {{"code", code}}).dump(),
|
||||
/* .id = */ "",
|
||||
});
|
||||
return msg;
|
||||
}
|
||||
static std::regex function_regex(R"(<function=(\w+)>)");
|
||||
static std::regex close_regex(R"(</function>)");
|
||||
@@ -872,10 +1300,10 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
|
||||
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
@@ -908,20 +1336,18 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
|
||||
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
|
||||
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
|
||||
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
|
||||
auto end = input.end();
|
||||
std::sregex_iterator rend;
|
||||
std::sregex_iterator rit(input.begin(), end, start_pattern);
|
||||
if (rit == rend) {
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ input,
|
||||
/* .tool_calls = */ {},
|
||||
};
|
||||
msg.content = input;
|
||||
return msg;
|
||||
}
|
||||
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
result.content = rit->prefix();
|
||||
msg.content = rit->prefix();
|
||||
|
||||
auto it = rit->suffix().first;
|
||||
while (it != end) {
|
||||
@@ -930,7 +1356,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
|
||||
throw std::runtime_error("Failed to parse json tool call");
|
||||
}
|
||||
const auto & arguments = call.at("arguments");
|
||||
result.tool_calls.push_back({
|
||||
msg.tool_calls.push_back({
|
||||
call.at("name"),
|
||||
arguments.dump(),
|
||||
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
@@ -947,17 +1373,17 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
|
||||
break;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
return msg;
|
||||
} catch (const std::exception & e) {
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ input,
|
||||
/* .tool_calls = */ {},
|
||||
};
|
||||
LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
msg.content = input;
|
||||
return msg;
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
@@ -973,12 +1399,35 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
|
||||
return data;
|
||||
}
|
||||
|
||||
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
static common_chat_params common_chat_templates_apply_jinja(
|
||||
const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs)
|
||||
{
|
||||
templates_params params;
|
||||
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
|
||||
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
|
||||
? *tmpls->template_tool_use
|
||||
: *tmpls->template_default;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
params.extract_reasoning = inputs.extract_reasoning;
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.grammar = inputs.grammar;
|
||||
if (!inputs.json_schema.empty()) {
|
||||
params.json_schema = json::parse(inputs.json_schema);
|
||||
}
|
||||
|
||||
if (inputs.tools.is_array()) {
|
||||
if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
|
||||
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
|
||||
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
|
||||
params.parallel_tool_calls = false;
|
||||
} else {
|
||||
params.parallel_tool_calls = inputs.parallel_tool_calls;
|
||||
}
|
||||
|
||||
if (params.tools.is_array()) {
|
||||
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
}
|
||||
if (caps.supports_tool_calls && !caps.supports_tools) {
|
||||
@@ -987,68 +1436,135 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
|
||||
}
|
||||
|
||||
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
|
||||
if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) {
|
||||
return common_chat_params_init_deepseek_r1(tmpl, inputs);
|
||||
if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) {
|
||||
return common_chat_params_init_deepseek_r1(tmpl, params);
|
||||
}
|
||||
|
||||
// Command R7B: : use handler in all cases except json schema (thinking / tools).
|
||||
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
|
||||
return common_chat_params_init_command_r7b(tmpl, inputs);
|
||||
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
|
||||
return common_chat_params_init_command_r7b(tmpl, params);
|
||||
}
|
||||
|
||||
// Use generic handler when mixing tools + JSON schema.
|
||||
// TODO: support that mix in handlers below.
|
||||
if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
|
||||
return common_chat_params_init_generic(tmpl, inputs);
|
||||
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
||||
return common_chat_params_init_generic(tmpl, params);
|
||||
}
|
||||
|
||||
// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
|
||||
if (src.find(">>>all") != std::string::npos) {
|
||||
return common_chat_params_init_functionary_v3_2(tmpl, inputs);
|
||||
return common_chat_params_init_functionary_v3_2(tmpl, params);
|
||||
}
|
||||
|
||||
// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
|
||||
if (src.find(" functools[") != std::string::npos) {
|
||||
return common_chat_params_init_firefunction_v2(tmpl, inputs);
|
||||
return common_chat_params_init_firefunction_v2(tmpl, params);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (inputs.tools.is_null() || inputs.tool_choice == "none") {
|
||||
return common_chat_params_init_without_tools(tmpl, inputs);
|
||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
}
|
||||
|
||||
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
||||
if (src.find("<tool_call>") != std::string::npos) {
|
||||
return common_chat_params_init_hermes_2_pro(tmpl, inputs);
|
||||
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
||||
}
|
||||
|
||||
// Functionary v3.1 (w/ tools)
|
||||
if (src.find("<|start_header_id|>") != std::string::npos
|
||||
&& src.find("<function=") != std::string::npos) {
|
||||
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
|
||||
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
|
||||
}
|
||||
|
||||
// Llama 3.1, 3.2, 3.3 (w/ tools)
|
||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
|
||||
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
|
||||
}
|
||||
|
||||
// Mistral Nemo (w/ tools)
|
||||
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
||||
return common_chat_params_init_mistral_nemo(tmpl, inputs);
|
||||
return common_chat_params_init_mistral_nemo(tmpl, params);
|
||||
}
|
||||
|
||||
// Generic fallback
|
||||
return common_chat_params_init_generic(tmpl, inputs);
|
||||
return common_chat_params_init_generic(tmpl, params);
|
||||
}
|
||||
|
||||
// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template.
|
||||
static common_chat_params common_chat_templates_apply_legacy(
|
||||
const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs)
|
||||
{
|
||||
int alloc_size = 0;
|
||||
std::vector<llama_chat_message> chat;
|
||||
std::vector<std::string> contents;
|
||||
for (const auto & msg : inputs.messages) {
|
||||
auto content = msg.content;
|
||||
for (const auto & part : msg.content_parts) {
|
||||
if (part.type != "text") {
|
||||
LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
|
||||
continue;
|
||||
}
|
||||
if (!content.empty()) {
|
||||
content += "\n";;
|
||||
}
|
||||
content += part.text;
|
||||
}
|
||||
contents.emplace_back(std::move(content));
|
||||
}
|
||||
for (size_t i = 0; i < contents.size(); ++i) {
|
||||
const auto & msg = inputs.messages[i];
|
||||
const auto & content = contents[i];
|
||||
chat.push_back({msg.role.c_str(), content.c_str()});
|
||||
alloc_size += (msg.role.size() + content.size()) * 1.25;
|
||||
}
|
||||
|
||||
std::vector<char> buf(alloc_size);
|
||||
|
||||
// run the first time to get the total output length
|
||||
const auto & src = tmpls->template_default->source();
|
||||
int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
||||
|
||||
// error: chat template is not supported
|
||||
if (res < 0) {
|
||||
// if the custom "tmpl" is not supported, we throw an error
|
||||
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
||||
throw std::runtime_error("this custom template is not supported");
|
||||
}
|
||||
|
||||
// if it turns out that our buffer is too small, we resize it
|
||||
if ((size_t) res > buf.size()) {
|
||||
buf.resize(res);
|
||||
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
||||
}
|
||||
|
||||
common_chat_params params;
|
||||
params.prompt = std::string(buf.data(), res);
|
||||
if (!inputs.json_schema.empty()) {
|
||||
params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
|
||||
} else {
|
||||
params.grammar = inputs.grammar;
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
common_chat_params common_chat_templates_apply(
|
||||
const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs)
|
||||
{
|
||||
GGML_ASSERT(tmpls != nullptr);
|
||||
return inputs.use_jinja
|
||||
? common_chat_templates_apply_jinja(tmpls, inputs)
|
||||
: common_chat_templates_apply_legacy(tmpls, inputs);
|
||||
}
|
||||
|
||||
static common_chat_msg common_chat_parse_content_only(const std::string & input) {
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ input,
|
||||
/* .tool_calls = */ {},
|
||||
};
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
msg.content = input;
|
||||
return msg;
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
|
||||
|
||||
134
common/chat.h
Normal file
134
common/chat.h
Normal file
@@ -0,0 +1,134 @@
|
||||
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.

#pragma once

#include "common.h"
#include <string>
#include <vector>

struct common_chat_templates;

struct common_chat_tool_call {
    std::string name;
    std::string arguments;
    std::string id;
};

struct common_chat_msg_content_part {
    std::string type;
    std::string text;
};

struct common_chat_msg {
    std::string role;
    std::string content;
    std::vector<common_chat_msg_content_part> content_parts = {};
    std::vector<common_chat_tool_call> tool_calls = {};
    std::string reasoning_content;
    std::string tool_name;
    std::string tool_call_id;
};

struct common_chat_tool {
    std::string name;
    std::string description;
    std::string parameters;
};

enum common_chat_tool_choice {
    COMMON_CHAT_TOOL_CHOICE_AUTO,
    COMMON_CHAT_TOOL_CHOICE_REQUIRED,
    COMMON_CHAT_TOOL_CHOICE_NONE,
};

enum common_chat_format {
    COMMON_CHAT_FORMAT_CONTENT_ONLY,
    COMMON_CHAT_FORMAT_GENERIC,
    COMMON_CHAT_FORMAT_MISTRAL_NEMO,
    COMMON_CHAT_FORMAT_LLAMA_3_X,
    COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
    COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
    COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
    COMMON_CHAT_FORMAT_HERMES_2_PRO,
    COMMON_CHAT_FORMAT_COMMAND_R7B,
    COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,

    COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};

struct common_chat_templates_inputs {
    std::vector<common_chat_msg> messages;
    std::string grammar;
    std::string json_schema;
    bool add_generation_prompt = true;
    bool use_jinja = true;
    // Parameters below only supported when use_jinja is true
    std::vector<common_chat_tool> tools;
    common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
    bool parallel_tool_calls = false;
    bool extract_reasoning = true;
};

struct common_chat_params {
    common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
    std::string prompt;
    std::string grammar;
    bool grammar_lazy = false;
    std::vector<common_grammar_trigger> grammar_triggers;
    std::vector<std::string> preserved_tokens;
    std::vector<std::string> additional_stops;
};

// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);

void common_chat_templates_free(struct common_chat_templates * tmpls);

struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };

typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;

common_chat_templates_ptr common_chat_templates_init(
    const struct llama_model * model,
    const std::string & chat_template_override,
    const std::string & bos_token_override = "",
    const std::string & eos_token_override = "");

bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);


struct common_chat_params common_chat_templates_apply(
    const struct common_chat_templates * tmpls,
    const struct common_chat_templates_inputs & inputs);

// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(
    const struct common_chat_templates * tmpls,
    const std::vector<common_chat_msg> & past_msg,
    const common_chat_msg & new_msg,
    bool add_ass,
    bool use_jinja);

// Returns an example of formatted chat
std::string common_chat_format_example(
    const struct common_chat_templates * tmpls,
    bool use_jinja);

std::string common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format);

common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);

// Parses a JSON array of messages in OpenAI's chat completion API format.
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);

// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
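Taken together, these declarations replace the old minja-typed interface with an opaque handle plus plain-struct inputs. A rough usage sketch (not part of the diff; the model pointer, message text, and helper name are placeholders, and error handling is omitted):

```cpp
#include "chat.h"

// Hypothetical helper: render a prompt with the new opaque-handle chat API.
// `model` is assumed to be a llama_model loaded elsewhere.
static std::string render_prompt(const struct llama_model * model) {
    // "" -> use the chat template embedded in the model's GGUF metadata
    common_chat_templates_ptr tmpls = common_chat_templates_init(model, "");

    common_chat_templates_inputs inputs;
    common_chat_msg msg;
    msg.role    = "user";
    msg.content = "Hello!";
    inputs.messages.push_back(msg);
    inputs.add_generation_prompt = true;
    inputs.use_jinja             = true;

    // Returns the rendered prompt plus any grammar/stop data derived from tools
    common_chat_params params = common_chat_templates_apply(tmpls.get(), inputs);
    return params.prompt;
}
```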
@@ -1,55 +0,0 @@
|
||||
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include <json.hpp>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
struct common_chat_inputs {
|
||||
json messages;
|
||||
json tools;
|
||||
json tool_choice;
|
||||
json json_schema;
|
||||
bool parallel_tool_calls;
|
||||
bool stream;
|
||||
std::string grammar;
|
||||
bool add_generation_prompt = true;
|
||||
bool extract_reasoning = true;
|
||||
};
|
||||
|
||||
enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
COMMON_CHAT_FORMAT_GENERIC,
|
||||
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
|
||||
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
struct common_chat_params {
|
||||
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
json prompt;
|
||||
std::string grammar;
|
||||
bool grammar_lazy = false;
|
||||
std::vector<common_grammar_trigger> grammar_triggers;
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
};
|
||||
|
||||
struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
|
||||
std::string common_chat_format_name(common_chat_format format);
|
||||
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
|
||||
@@ -12,8 +12,6 @@
|
||||
#include "json.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
#include "chat.hpp"
|
||||
#include "chat-template.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
@@ -1768,174 +1766,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
|
||||
return text;
|
||||
}
|
||||
|
||||
//
|
||||
// Chat template utils
|
||||
//
|
||||
|
||||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
try {
|
||||
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
|
||||
common_chat_inputs inputs;
|
||||
inputs.messages = json::array({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}});
|
||||
common_chat_params_init(chat_template, inputs);
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||
return res >= 0;
|
||||
}
|
||||
|
||||
std::string common_chat_apply_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::vector<common_chat_msg> & msgs,
|
||||
bool add_ass,
|
||||
bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
auto messages = json::array();
|
||||
for (const auto & msg : msgs) {
|
||||
messages.push_back({{"role", msg.role}, {"content", msg.content}});
|
||||
}
|
||||
common_chat_inputs inputs;
|
||||
inputs.messages = messages;
|
||||
inputs.add_generation_prompt = add_ass;
|
||||
return common_chat_params_init(tmpl, inputs).prompt;
|
||||
}
|
||||
|
||||
int alloc_size = 0;
|
||||
std::vector<llama_chat_message> chat;
|
||||
for (const auto & msg : msgs) {
|
||||
chat.push_back({msg.role.c_str(), msg.content.c_str()});
|
||||
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
|
||||
}
|
||||
|
||||
std::vector<char> buf(alloc_size);
|
||||
|
||||
// run the first time to get the total output length
|
||||
int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
|
||||
// error: chat template is not supported
|
||||
if (res < 0) {
|
||||
// if the custom "tmpl" is not supported, we throw an error
|
||||
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
||||
throw std::runtime_error("this custom template is not supported");
|
||||
}
|
||||
|
||||
// if it turns out that our buffer is too small, we resize it
|
||||
if ((size_t) res > buf.size()) {
|
||||
buf.resize(res);
|
||||
res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
}
|
||||
|
||||
std::string formatted_chat(buf.data(), res);
|
||||
return formatted_chat;
|
||||
}
|
||||
|
||||
std::string common_chat_format_single(
|
||||
const common_chat_template & tmpl,
|
||||
const std::vector<common_chat_msg> & past_msg,
|
||||
const common_chat_msg & new_msg,
|
||||
bool add_ass,
|
||||
bool use_jinja) {
|
||||
std::ostringstream ss;
|
||||
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
|
||||
std::vector<common_chat_msg> chat_new(past_msg);
|
||||
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
||||
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
||||
ss << "\n";
|
||||
};
|
||||
// format chat with new_msg
|
||||
chat_new.push_back(new_msg);
|
||||
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
|
||||
// get the diff part
|
||||
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
|
||||
std::vector<common_chat_msg> msgs = {
|
||||
{"system", "You are a helpful assistant", {}},
|
||||
{"user", "Hello", {}},
|
||||
{"assistant", "Hi there", {}},
|
||||
{"user", "How are you?", {}},
|
||||
};
|
||||
return common_chat_apply_template(tmpl, msgs, true, use_jinja);
|
||||
}
|
||||
|
||||
#define CHATML_TEMPLATE_SRC \
|
||||
"{%- for message in messages -%}\n" \
|
||||
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
|
||||
"{%- endfor -%}\n" \
|
||||
"{%- if add_generation_prompt -%}\n" \
|
||||
" {{- '<|im_start|>assistant\n' -}}\n" \
|
||||
"{%- endif -%}"
|
||||
|
||||
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
|
||||
{
|
||||
std::string default_template_src;
|
||||
std::string template_tool_use_src;
|
||||
|
||||
bool has_explicit_template = !chat_template_override.empty();
|
||||
if (chat_template_override.empty()) {
|
||||
auto str = llama_model_chat_template(model, /* name */ nullptr);
|
||||
if (str) {
|
||||
default_template_src = str;
|
||||
has_explicit_template = true;
|
||||
}
|
||||
str = llama_model_chat_template(model, /* name */ "tool_use");
|
||||
if (str) {
|
||||
template_tool_use_src = str;
|
||||
has_explicit_template = true;
|
||||
}
|
||||
} else {
|
||||
default_template_src = chat_template_override;
|
||||
}
|
||||
if (default_template_src.empty() || default_template_src == "chatml") {
|
||||
if (!template_tool_use_src.empty()) {
|
||||
default_template_src = template_tool_use_src;
|
||||
} else {
|
||||
default_template_src = CHATML_TEMPLATE_SRC;
|
||||
}
|
||||
}
|
||||
auto vocab = llama_model_get_vocab(model);
|
||||
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
|
||||
if (token == LLAMA_TOKEN_NULL) {
|
||||
if (default_template_src.find(jinja_variable_name) != std::string::npos
|
||||
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
|
||||
LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
|
||||
}
|
||||
return std::string();
|
||||
} else {
|
||||
return common_token_to_piece(vocab, token, true);
|
||||
}
|
||||
};
|
||||
auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
|
||||
auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
|
||||
try {
|
||||
return {
|
||||
has_explicit_template,
|
||||
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
|
||||
template_tool_use_src.empty()
|
||||
? nullptr
|
||||
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
|
||||
};
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what());
|
||||
return {
|
||||
has_explicit_template,
|
||||
std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
|
||||
nullptr,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// KV cache utils
|
||||
//
|
||||
|
||||
@@ -178,10 +178,10 @@ struct common_params_speculative {
|
||||
|
||||
int32_t n_ctx = 0; // draft context size
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.9f; // minimum speculative decoding probability (greedy)
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
|
||||
struct cpu_params cpuparams;
|
||||
struct cpu_params cpuparams_batch;
|
||||
@@ -616,62 +616,6 @@ std::string common_detokenize(
|
||||
const std::vector<llama_token> & tokens,
|
||||
bool special = true);
|
||||
|
||||
//
|
||||
// Chat template utils
|
||||
//
|
||||
|
||||
struct common_tool_call {
|
||||
std::string name;
|
||||
std::string arguments;
|
||||
std::string id;
|
||||
};
|
||||
|
||||
// same with llama_chat_message, but uses std::string
|
||||
struct common_chat_msg {
|
||||
std::string role;
|
||||
std::string content;
|
||||
std::vector<common_tool_call> tool_calls;
|
||||
std::string reasoning_content = "";
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
|
||||
|
||||
namespace minja {
|
||||
class chat_template;
|
||||
}
|
||||
|
||||
typedef minja::chat_template common_chat_template;
|
||||
|
||||
struct common_chat_templates {
|
||||
bool has_explicit_template; // Model had a builtin template or a template override was specified.
|
||||
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
|
||||
std::unique_ptr<common_chat_template> template_tool_use;
|
||||
};
|
||||
|
||||
// CPP wrapper for llama_chat_apply_template
|
||||
// If the built-in template is not supported, we default to chatml
|
||||
// If the custom "tmpl" is not supported, we throw an error
|
||||
std::string common_chat_apply_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::vector<common_chat_msg> & chat,
|
||||
bool add_ass,
|
||||
bool use_jinja);
|
||||
|
||||
// Format single message, while taking into account the position of that message in chat history
|
||||
std::string common_chat_format_single(
|
||||
const common_chat_template & tmpl,
|
||||
const std::vector<common_chat_msg> & past_msg,
|
||||
const common_chat_msg & new_msg,
|
||||
bool add_ass,
|
||||
bool use_jinja);
|
||||
|
||||
// Returns an example of formatted chat
|
||||
std::string common_chat_format_example(
|
||||
const common_chat_template & tmpl, bool use_jinja);
|
||||
|
||||
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
|
||||
|
||||
//
|
||||
// KV cache utils
|
||||
//
|
||||
|
||||
@@ -252,11 +252,6 @@ llama_tokens common_speculative_gen_draft(
|
||||
// add drafted token for each sequence
|
||||
const llama_token id = cur_p->data[0].id;
|
||||
|
||||
// only collect very high-confidence draft tokens
|
||||
if (cur_p->data[0].p < params.p_min) {
|
||||
break;
|
||||
}
|
||||
|
||||
common_sampler_accept(smpl, id, true);
|
||||
|
||||
result.push_back(id);
|
||||
@@ -265,6 +260,11 @@ llama_tokens common_speculative_gen_draft(
|
||||
break;
|
||||
}
|
||||
|
||||
// only collect very high-confidence draft tokens
|
||||
if (cur_p->data[0].p < params.p_min) {
|
||||
break;
|
||||
}
|
||||
|
||||
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
|
||||
|
||||
// evaluate the drafted tokens on the draft model
|
||||
|
||||
@@ -9,7 +9,7 @@ struct common_speculative_params {
    int n_draft = 16;  // max drafted tokens
    int n_reuse = 256;

    float p_min = 0.9f;  // min probability required to accept a token in the draft
    float p_min = 0.75f; // min probability required to accept a token in the draft
};

struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);

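These defaults interact with the early-exit shown in `common_speculative_gen_draft` above: the loop breaks as soon as the top draft probability drops below `p_min`, so the lower default drafts further into lower-confidence tokens. A hedged sketch of overriding the knobs (assumes a draft context `ctx_dft` created elsewhere; not code from this diff):

```cpp
// Sketch: tune how aggressively the draft model speculates.
struct common_speculative * spec = common_speculative_init(ctx_dft);

common_speculative_params sparams;
sparams.n_draft = 16;     // cap on tokens drafted per step
sparams.p_min   = 0.75f;  // new default: keep drafting through lower-confidence tokens
// Raising p_min back toward 0.9 yields shorter but more reliably accepted drafts.
```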
@@ -42,6 +42,16 @@ The following release is verified with good quality:

## News

- 2025.2
  - Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. This increases LLM performance (llama-2-7b.Q4_0.gguf) by 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC).

    | GPU      | Base tokens/s | Increased tokens/s | Percent |
    |----------|---------------|--------------------|---------|
    | PVC 1550 | 39            | 73                 | +87%    |
    | Flex 170 | 39            | 50                 | +28%    |
    | Arc770   | 42            | 55                 | +30%    |
    | MTL      | 13            | 16                 | +23%    |
    | ARL-H    | 14            | 17                 | +21%    |

- 2024.11
  - Use syclcompat to improve performance on some platforms. This requires oneAPI 2025.0 or newer.

@@ -97,8 +107,8 @@ SYCL backend supports Intel GPU Family:
| Intel Data Center Max Series | Support | Max 1550, 1100 |
| Intel Data Center Flex Series | Support | Flex 170 |
| Intel Arc Series | Support | Arc 770, 730M, Arc A750 |
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake |
| Intel iGPU | Support | iGPU in 13700k, i5-1250P, i7-1260P, i7-1165G7 |
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake |
| Intel iGPU | Support | iGPU in 13700k, iGPU in 13400, i5-1250P, i7-1260P, i7-1165G7 |

*Notes:*

@@ -660,8 +670,10 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| Name | Value | Function |
|-------------------|------------------|-----------------------------------------------------------------------------------------------------------------------------|
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimized features based on Intel GPU type, to compare the performance increase |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |


## Known Issues

- `Split-mode:[row]` is not supported.

@@ -206,6 +206,14 @@ This provides GPU acceleration using the MUSA cores of your Moore Threads MTT GPUs.
cmake --build build --config Release
```

For a static build:

```bash
cmake -B build -DGGML_MUSA=ON \
    -DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
cmake --build build --config Release
```

The environment variable [`MUSA_VISIBLE_DEVICES`](https://docs.mthreads.com/musa-sdk/musa-sdk-doc-online/programming_guide/Z%E9%99%84%E5%BD%95/) can be used to specify which GPU(s) will be used.

The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory on Linux. This allows swapping to system RAM instead of crashing when GPU VRAM is exhausted.

@@ -1,6 +1,6 @@
# llama.cpp/examples/imatrix

Compute an importance matrix for a model and given text dataset. Can be used during quantization to enchance the quality of the quantized models.
Compute an importance matrix for a model and given text dataset. Can be used during quantization to enhance the quality of the quantized models.
More information is available here: https://github.com/ggml-org/llama.cpp/pull/4861

## Usage

@@ -124,15 +124,26 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
}.sheet(isPresented: $showingHelp) { // Sheet for help modal
|
||||
VStack(alignment: .leading) {
|
||||
NavigationView {
|
||||
VStack(alignment: .leading) {
|
||||
Text("1. Make sure the model is in GGUF Format")
|
||||
.padding()
|
||||
Text("2. Copy the download link of the quantized model")
|
||||
.padding()
|
||||
VStack(alignment: .leading) {
|
||||
Text("1. Make sure the model is in GGUF Format")
|
||||
.padding()
|
||||
Text("2. Copy the download link of the quantized model")
|
||||
.padding()
|
||||
}
|
||||
Spacer()
|
||||
}
|
||||
Spacer()
|
||||
}
|
||||
.navigationTitle("Help")
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
.toolbar {
|
||||
ToolbarItem(placement: .navigationBarTrailing) {
|
||||
Button("Done") {
|
||||
showingHelp = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1729,11 +1729,11 @@ void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) {
|
||||
}
|
||||
}
|
||||
|
||||
static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) {
|
||||
void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
|
||||
img->nx = nx;
|
||||
img->ny = ny;
|
||||
img->buf.resize(3 * nx * ny);
|
||||
memcpy(img->buf.data(), data, img->buf.size());
|
||||
memcpy(img->buf.data(), rgb_pixels, img->buf.size());
|
||||
}
|
||||
|
||||
bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
|
||||
@@ -1743,7 +1743,7 @@ bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
|
||||
LOG_ERR("%s: failed to load image '%s'\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
build_clip_img_from_data(data, nx, ny, img);
|
||||
clip_build_img_from_pixels(data, nx, ny, img);
|
||||
stbi_image_free(data);
|
||||
return true;
|
||||
}
|
||||
@@ -1755,7 +1755,7 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length
|
||||
LOG_ERR("%s: failed to decode image bytes\n", __func__);
|
||||
return false;
|
||||
}
|
||||
build_clip_img_from_data(data, nx, ny, img);
|
||||
clip_build_img_from_pixels(data, nx, ny, img);
|
||||
stbi_image_free(data);
|
||||
return true;
|
||||
}
|
||||
@@ -2712,9 +2712,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
|
||||
if (!ctx->has_glm_projector) {
|
||||
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
|
||||
// The patches vector is used to get rows to index into the embeds with;
|
||||
// we should skip dim 0 only if we have CLS to avoid going out of bounds
|
||||
// when retrieving the rows.
|
||||
int patch_offset = ctx->has_class_embedding ? 1 : 0;
|
||||
int* patches_data = (int*)malloc(ggml_nbytes(patches));
|
||||
for (int i = 0; i < num_patches; i++) {
|
||||
patches_data[i] = i + 1;
|
||||
patches_data[i] = i + patch_offset;
|
||||
}
|
||||
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
|
||||
free(patches_data);
|
||||
|
||||
@@ -73,6 +73,9 @@ CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);

/** build image from pixels decoded by other libraries instead of stb_image.h for better performance. The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes */
CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img);

CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);

/** interpret bytes as an image file with length bytes_length, and use the result to populate img */

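A possible caller of the new entry point, sketched here for illustration only (the names `rgb`, `width`, `height` and the wrapper itself are assumptions, not part of this change):

```cpp
// Sketch: hand pre-decoded RGB pixels to clip, bypassing stb_image.
// `rgb` must point to 3 * width * height bytes laid out as RGBRGBRGB...
static bool load_decoded_image(const unsigned char * rgb, int width, int height, struct clip_image_u8 * img) {
    if (rgb == nullptr || width <= 0 || height <= 0) {
        return false;
    }
    clip_build_img_from_pixels(rgb, width, height, img);  // copies the pixel buffer into img
    return true;
}
```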
@@ -4,7 +4,7 @@
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "llama.h"
|
||||
#include "chat-template.hpp"
|
||||
#include "chat.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
@@ -158,7 +158,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
auto chat_templates = common_chat_templates_from_model(model, params.chat_template);
|
||||
auto chat_templates = common_chat_templates_init(model, params.chat_template);
|
||||
|
||||
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
|
||||
|
||||
@@ -201,7 +201,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// auto enable conversation mode if chat template is available
|
||||
const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default;
|
||||
const bool has_chat_template = common_chat_templates_was_explicit(chat_templates.get());
|
||||
if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
|
||||
if (has_chat_template) {
|
||||
LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
|
||||
@@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
|
||||
// print chat template example in conversation mode
|
||||
if (params.conversation_mode) {
|
||||
if (params.enable_chat_template) {
|
||||
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
|
||||
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str());
|
||||
} else {
|
||||
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
|
||||
}
|
||||
@@ -264,9 +264,11 @@ int main(int argc, char ** argv) {
    std::vector<llama_token> embd_inp;

    auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
        common_chat_msg new_msg{role, content, {}};
        auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
        chat_msgs.push_back({role, content, {}});
        common_chat_msg new_msg;
        new_msg.role = role;
        new_msg.content = content;
        auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja);
        chat_msgs.push_back(new_msg);
        LOG_DBG("formatted: '%s'\n", formatted.c_str());
        return formatted;
    };
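For reference, `common_chat_format_single` renders the history with and without the new message and returns only the difference, which is what lets `chat_add_and_format` append one turn at a time. A hedged sketch (the template handle comes from the code above; the ChatML-style output in the comment is indicative only):

```cpp
// Sketch: incremental formatting of a single new user turn.
std::vector<common_chat_msg> history;  // turns formatted so far

common_chat_msg turn;
turn.role    = "user";
turn.content = "How are you?";

// With a ChatML-style template this returns roughly:
//   "<|im_start|>user\nHow are you?<|im_end|>\n<|im_start|>assistant\n"
std::string delta = common_chat_format_single(chat_templates.get(), history, turn,
                                               /* add_ass   */ true,
                                               /* use_jinja */ true);
history.push_back(turn);
```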
@@ -755,11 +757,14 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// check for reverse prompt using special tokens
|
||||
llama_token last_token = common_sampler_last(smpl);
|
||||
if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) {
|
||||
if (params.interactive) {
|
||||
is_interacting = true;
|
||||
for (auto token : antiprompt_token) {
|
||||
if (token == last_token) {
|
||||
if (params.interactive) {
|
||||
is_interacting = true;
|
||||
}
|
||||
is_antiprompt = true;
|
||||
break;
|
||||
}
|
||||
is_antiprompt = true;
|
||||
}
|
||||
|
||||
if (is_antiprompt) {
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "chat-template.hpp"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "json.hpp"
|
||||
#include "linenoise.cpp/linenoise.h"
|
||||
@@ -113,6 +113,7 @@ class Opt {
|
||||
llama_context_params ctx_params;
|
||||
llama_model_params model_params;
|
||||
std::string model_;
|
||||
std::string chat_template_file;
|
||||
std::string user;
|
||||
bool use_jinja = false;
|
||||
int context_size = -1, ngl = -1;
|
||||
@@ -148,6 +149,16 @@ class Opt {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
|
||||
if (i + 1 >= argc) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
option_value = argv[++i];
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int parse(int argc, const char ** argv) {
|
||||
bool options_parsing = true;
|
||||
for (int i = 1, positional_args_i = 0; i < argc; ++i) {
|
||||
@@ -169,6 +180,11 @@ class Opt {
|
||||
verbose = true;
|
||||
} else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
|
||||
use_jinja = true;
|
||||
} else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0){
|
||||
if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
|
||||
return 1;
|
||||
}
|
||||
use_jinja = true;
|
||||
} else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
|
||||
help = true;
|
||||
return 0;
|
||||
@@ -207,6 +223,11 @@ class Opt {
|
||||
"Options:\n"
|
||||
" -c, --context-size <value>\n"
|
||||
" Context size (default: %d)\n"
|
||||
" --chat-template-file <path>\n"
|
||||
" Path to the file containing the chat template to use with the model.\n"
|
||||
" Only supports jinja templates and implicitly sets the --jinja flag.\n"
|
||||
" --jinja\n"
|
||||
" Use jinja templating for the chat template of the model\n"
|
||||
" -n, -ngl, --ngl <value>\n"
|
||||
" Number of GPU layers (default: %d)\n"
|
||||
" --temp <value>\n"
|
||||
@@ -261,13 +282,12 @@ static int get_terminal_width() {
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef LLAMA_USE_CURL
|
||||
class File {
|
||||
public:
|
||||
FILE * file = nullptr;
|
||||
|
||||
FILE * open(const std::string & filename, const char * mode) {
|
||||
file = fopen(filename.c_str(), mode);
|
||||
file = ggml_fopen(filename.c_str(), mode);
|
||||
|
||||
return file;
|
||||
}
|
||||
@@ -303,6 +323,20 @@ class File {
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string to_string() {
|
||||
fseek(file, 0, SEEK_END);
|
||||
const size_t size = ftell(file);
|
||||
fseek(file, 0, SEEK_SET);
|
||||
std::string out;
|
||||
out.resize(size);
|
||||
const size_t read_size = fread(&out[0], 1, size, file);
|
||||
if (read_size != size) {
|
||||
printe("Error reading file: %s", strerror(errno));
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
~File() {
|
||||
if (fd >= 0) {
|
||||
# ifdef _WIN32
|
||||
@@ -327,6 +361,7 @@ class File {
|
||||
# endif
|
||||
};
|
||||
|
||||
#ifdef LLAMA_USE_CURL
|
||||
class HttpClient {
|
||||
public:
|
||||
int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
|
||||
@@ -557,7 +592,7 @@ class LlamaData {
|
||||
llama_model_ptr model;
|
||||
llama_sampler_ptr sampler;
|
||||
llama_context_ptr context;
|
||||
std::vector<llama_chat_message> messages;
|
||||
std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
|
||||
std::list<std::string> msg_strs;
|
||||
std::vector<char> fmtted;
|
||||
|
||||
@@ -834,44 +869,23 @@ static void add_message(const char * role, const std::string & text, LlamaData &
|
||||
}
|
||||
|
||||
// Function to apply the chat template and resize `formatted` if needed
|
||||
static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
json messages = json::array();
|
||||
for (const auto & msg : llama_data.messages) {
|
||||
messages.push_back({
|
||||
{"role", msg.role},
|
||||
{"content", msg.content},
|
||||
});
|
||||
}
|
||||
try {
|
||||
minja::chat_template_inputs tmpl_inputs;
|
||||
tmpl_inputs.messages = messages;
|
||||
tmpl_inputs.add_generation_prompt = append;
|
||||
|
||||
minja::chat_template_options tmpl_opts;
|
||||
tmpl_opts.use_bos_token = false;
|
||||
tmpl_opts.use_eos_token = false;
|
||||
|
||||
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
|
||||
llama_data.fmtted.resize(result.size() + 1);
|
||||
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
|
||||
return result.size();
|
||||
} catch (const std::exception & e) {
|
||||
printe("failed to render the chat template: %s\n", e.what());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
int result = llama_chat_apply_template(
|
||||
tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
|
||||
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
|
||||
if (append && result > static_cast<int>(llama_data.fmtted.size())) {
|
||||
llama_data.fmtted.resize(result);
|
||||
result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
|
||||
llama_data.messages.size(), append, llama_data.fmtted.data(),
|
||||
llama_data.fmtted.size());
|
||||
static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) {
|
||||
common_chat_templates_inputs inputs;
|
||||
for (const auto & msg : llama_data.messages) {
|
||||
common_chat_msg cmsg;
|
||||
cmsg.role = msg.role;
|
||||
cmsg.content = msg.content;
|
||||
inputs.messages.push_back(cmsg);
|
||||
}
|
||||
inputs.add_generation_prompt = append;
|
||||
inputs.use_jinja = use_jinja;
|
||||
|
||||
return result;
|
||||
auto chat_params = common_chat_templates_apply(tmpls, inputs);
|
||||
// TODO: use other params for tool calls.
|
||||
auto result = chat_params.prompt;
|
||||
llama_data.fmtted.resize(result.size() + 1);
|
||||
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
|
||||
return result.size();
|
||||
}
|
||||
|
||||
// Function to tokenize the prompt
|
||||
@@ -963,7 +977,8 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
|
||||
}
|
||||
|
||||
static int read_user_input(std::string & user_input) {
|
||||
static const char * prompt_prefix = "> ";
|
||||
static const char * prompt_prefix_env = std::getenv("LLAMA_PROMPT_PREFIX");
|
||||
static const char * prompt_prefix = prompt_prefix_env ? prompt_prefix_env : "> ";
|
||||
#ifdef WIN32
|
||||
printf("\r" LOG_CLR_TO_EOL LOG_COL_DEFAULT "%s", prompt_prefix);
|
||||
|
||||
@@ -1015,8 +1030,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
|
||||
}
|
||||
|
||||
// Helper function to apply the chat template and handle errors
|
||||
static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
|
||||
const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
|
||||
static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
|
||||
const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja);
|
||||
if (new_len < 0) {
|
||||
printe("failed to apply the chat template\n");
|
||||
return -1;
|
||||
@@ -1074,40 +1089,68 @@ static int get_user_input(std::string & user_input, const std::string & user) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Reads a chat template file to be used
|
||||
static std::string read_chat_template_file(const std::string & chat_template_file) {
|
||||
File file;
|
||||
if (!file.open(chat_template_file, "r")) {
|
||||
printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
|
||||
return "";
|
||||
}
|
||||
|
||||
return file.to_string();
|
||||
}
|
||||
|
||||
static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
|
||||
const common_chat_templates_ptr & chat_templates, int & prev_len,
|
||||
const bool stdout_a_terminal) {
|
||||
add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
|
||||
int new_len;
|
||||
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
|
||||
std::string response;
|
||||
if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!opt.user.empty()) {
|
||||
return 2;
|
||||
}
|
||||
|
||||
add_message("assistant", response, llama_data);
|
||||
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Main chat loop function
|
||||
static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
|
||||
static int chat_loop(LlamaData & llama_data, const Opt & opt) {
|
||||
int prev_len = 0;
|
||||
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
|
||||
auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
|
||||
GGML_ASSERT(chat_templates.template_default);
|
||||
std::string chat_template;
|
||||
if (!opt.chat_template_file.empty()) {
|
||||
chat_template = read_chat_template_file(opt.chat_template_file);
|
||||
}
|
||||
|
||||
common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
|
||||
static const bool stdout_a_terminal = is_stdout_a_terminal();
|
||||
while (true) {
|
||||
// Get user input
|
||||
std::string user_input;
|
||||
if (get_user_input(user_input, user) == 1) {
|
||||
if (get_user_input(user_input, opt.user) == 1) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
add_message("user", user.empty() ? user_input : user, llama_data);
|
||||
int new_len;
|
||||
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
|
||||
const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
|
||||
if (ret == 1) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
|
||||
std::string response;
|
||||
if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!user.empty()) {
|
||||
} else if (ret == 2) {
|
||||
break;
|
||||
}
|
||||
|
||||
add_message("assistant", response, llama_data);
|
||||
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
@@ -1165,7 +1208,7 @@ int main(int argc, const char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
|
||||
if (chat_loop(llama_data, opt)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Binary file not shown.
@@ -274,7 +274,7 @@ struct server_task {
|
||||
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
|
||||
|
||||
params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
|
||||
params.speculative.n_min = std::max(params.speculative.n_min, 2);
|
||||
params.speculative.n_min = std::max(params.speculative.n_min, 0);
|
||||
params.speculative.n_max = std::max(params.speculative.n_max, 0);
|
||||
|
||||
// Use OpenAI API logprobs only if n_probs wasn't provided
|
||||
@@ -329,9 +329,6 @@ struct server_task {
|
||||
}
|
||||
|
||||
// process "json_schema" and "grammar"
|
||||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
||||
}
|
||||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
||||
try {
|
||||
auto schema = json_value(data, "json_schema", json::object());
|
||||
@@ -1807,7 +1804,7 @@ struct server_context {
|
||||
// Necessary similarity of prompt for slot selection
|
||||
float slot_prompt_similarity = 0.0f;
|
||||
|
||||
common_chat_templates chat_templates;
|
||||
common_chat_templates_ptr chat_templates;
|
||||
|
||||
~server_context() {
|
||||
// Clear any sampling context
|
||||
@@ -1891,45 +1888,17 @@ struct server_context {
|
||||
llama_init_dft.context.reset();
|
||||
}
|
||||
|
||||
if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
|
||||
chat_templates = common_chat_templates_init(model, params_base.chat_template);
|
||||
try {
|
||||
common_chat_format_example(chat_templates.get(), params.use_jinja);
|
||||
} catch (const std::exception & e) {
|
||||
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
|
||||
chat_templates = common_chat_templates_from_model(model, "chatml");
|
||||
} else {
|
||||
chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
|
||||
chat_templates = common_chat_templates_init(model, "chatml");
|
||||
}
|
||||
GGML_ASSERT(chat_templates.template_default.get() != nullptr);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool validate_builtin_chat_template(bool use_jinja) const {
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
|
||||
if (use_jinja) {
|
||||
auto templates = common_chat_templates_from_model(model, "");
|
||||
common_chat_inputs inputs;
|
||||
inputs.messages = json::array({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}});
|
||||
GGML_ASSERT(templates.template_default);
|
||||
try {
|
||||
common_chat_params_init(*templates.template_default, inputs);
|
||||
if (templates.template_tool_use) {
|
||||
common_chat_params_init(*templates.template_tool_use, inputs);
|
||||
}
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("failed to apply template: %s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
|
||||
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
|
||||
return chat_res > 0;
|
||||
}
|
||||
}
|
||||
|
||||
void init() {
|
||||
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
|
||||
|
||||
@@ -3656,7 +3625,7 @@ int main(int argc, char ** argv) {
|
||||
}, {
|
||||
{"name", "n_busy_slots_per_decode"},
|
||||
{"help", "Average number of busy slots per llama_decode() call"},
|
||||
{"value", (float) res_metrics->n_busy_slots_total / (float) res_metrics->n_decode_total}
|
||||
{"value", (float) res_metrics->n_busy_slots_total / std::max((float) res_metrics->n_decode_total, 1.f)}
|
||||
}}},
|
||||
{"gauge", {{
|
||||
{"name", "prompt_tokens_seconds"},
|
||||
@@ -3822,13 +3791,15 @@ int main(int argc, char ** argv) {
|
||||
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
||||
{ "total_slots", ctx_server.params_base.n_parallel },
|
||||
{ "model_path", ctx_server.params_base.model },
|
||||
{ "chat_template", ctx_server.chat_templates.template_default->source() },
|
||||
{ "bos_token", ctx_server.chat_templates.template_default->bos_token() },
|
||||
{ "eos_token", ctx_server.chat_templates.template_default->eos_token() },
|
||||
{ "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
|
||||
{ "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
|
||||
{ "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
|
||||
{ "build_info", build_info },
|
||||
};
|
||||
if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
|
||||
data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
|
||||
if (ctx_server.params_base.use_jinja) {
|
||||
if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
|
||||
data["chat_template_tool_use"] = tool_use_src;
|
||||
}
|
||||
}
|
||||
|
||||
res_ok(res, data);
|
||||
@@ -4063,7 +4034,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
auto body = json::parse(req.body);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
|
||||
|
||||
return handle_completions_impl(
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
@@ -4076,7 +4047,7 @@ int main(int argc, char ** argv) {
|
||||
// same with handle_chat_completions, but without inference part
|
||||
const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
auto body = json::parse(req.body);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
|
||||
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
|
||||
};
|
||||
|
||||
@@ -4263,6 +4234,11 @@ int main(int argc, char ** argv) {
    // return;
    //}

    // if true, use TEI API format, otherwise use Jina API format
    // Jina: https://jina.ai/reranker/
    // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
    bool is_tei_format = body.contains("texts");

    json query;
    if (body.count("query") == 1) {
        query = body.at("query");
@@ -4275,7 +4251,8 @@ int main(int argc, char ** argv) {
        return;
    }

    std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
    std::vector<std::string> documents = json_value(body, "documents",
        json_value(body, "texts", std::vector<std::string>()));
    if (documents.empty()) {
        res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
        return;

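To make the two accepted request shapes concrete, here is a hedged illustration built with the JSON library the server already uses (query and document strings are placeholders). A Jina-style body uses `documents` and gets a `{"results": [...]}` response, while a TEI-style body uses `texts` and gets a bare array back, as exercised by the rerank tests further down:

```cpp
#include "json.hpp"
using ordered_json = nlohmann::ordered_json;

// Jina-style rerank request (response: {"results": [...]}).
ordered_json jina_req = {
    {"query",     "Machine learning is"},
    {"documents", {"first document", "second document"}},
};

// TEI-style rerank request, detected via the "texts" key (response: bare array).
ordered_json tei_req = {
    {"query", "Machine learning is"},
    {"texts", {"first document", "second document"}},
};
```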
@@ -4320,7 +4297,12 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// write JSON response
|
||||
json root = format_response_rerank(body, responses);
|
||||
json root = format_response_rerank(
|
||||
body,
|
||||
responses,
|
||||
is_tei_format,
|
||||
documents);
|
||||
|
||||
res_ok(res, root);
|
||||
};
|
||||
|
||||
@@ -4482,8 +4464,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// print sample chat example to make it clear which template is used
|
||||
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
|
||||
ctx_server.chat_templates.template_default->source().c_str(),
|
||||
common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
|
||||
common_chat_templates_source(ctx_server.chat_templates.get()),
|
||||
common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
|
||||
|
||||
ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
|
||||
ctx_server.process_single_task(task);
|
||||
|
||||
@@ -48,7 +48,7 @@ DEBUG=1 ./tests.sh -s -v -x
To run all the tests in a file:

```shell
./tests.sh unit/test_chat_completion.py.py -v -x
./tests.sh unit/test_chat_completion.py -v -x
```

To run a single test:

@@ -21,6 +21,8 @@ def create_server():
|
||||
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
|
||||
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
|
||||
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
|
||||
]
|
||||
)
|
||||
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
|
||||
@@ -44,7 +46,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
|
||||
assert res.body["usage"]["completion_tokens"] == n_predicted
|
||||
choice = res.body["choices"][0]
|
||||
assert "assistant" == choice["message"]["role"]
|
||||
assert match_regex(re_content, choice["message"]["content"])
|
||||
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
|
||||
assert choice["finish_reason"] == finish_reason
|
||||
|
||||
|
||||
@@ -169,6 +171,47 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
|
||||
assert "error" in res.body
|
||||
|
||||
|
||||
@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
|
||||
(False, {"const": "42"}, 6, "\"42\""),
|
||||
(True, {"const": "42"}, 6, "\"42\""),
|
||||
])
|
||||
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
|
||||
global server
|
||||
server.jinja = jinja
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predicted,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "user", "content": "Write an example"},
|
||||
],
|
||||
"json_schema": json_schema,
|
||||
})
|
||||
assert res.status_code == 200, f'Expected 200, got {res.status_code}'
|
||||
choice = res.body["choices"][0]
|
||||
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
|
||||
|
||||
|
||||
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
|
||||
(False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
|
||||
(True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
|
||||
])
|
||||
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
|
||||
global server
|
||||
server.jinja = jinja
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predicted,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Does not matter what I say, does it?"},
|
||||
],
|
||||
"grammar": grammar,
|
||||
})
|
||||
assert res.status_code == 200, res.body
|
||||
choice = res.body["choices"][0]
|
||||
assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("messages", [
|
||||
None,
|
||||
"string",
|
||||
|
||||
@@ -10,17 +10,20 @@ def create_server():
|
||||
server = ServerPreset.jina_reranker_tiny()
|
||||
|
||||
|
||||
TEST_DOCUMENTS = [
|
||||
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
|
||||
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
|
||||
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
|
||||
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
|
||||
]
|
||||
|
||||
|
||||
def test_rerank():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/rerank", data={
|
||||
"query": "Machine learning is",
|
||||
"documents": [
|
||||
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
|
||||
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
|
||||
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
|
||||
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
|
||||
]
|
||||
"documents": TEST_DOCUMENTS,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["results"]) == 4
|
||||
@@ -38,6 +41,29 @@ def test_rerank():
|
||||
assert least_relevant["index"] == 3
|
||||
|
||||
|
||||
def test_rerank_tei_format():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/rerank", data={
|
||||
"query": "Machine learning is",
|
||||
"texts": TEST_DOCUMENTS,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body) == 4
|
||||
|
||||
most_relevant = res.body[0]
|
||||
least_relevant = res.body[0]
|
||||
for doc in res.body:
|
||||
if doc["score"] > most_relevant["score"]:
|
||||
most_relevant = doc
|
||||
if doc["score"] < least_relevant["score"]:
|
||||
least_relevant = doc
|
||||
|
||||
assert most_relevant["score"] > least_relevant["score"]
|
||||
assert most_relevant["index"] == 2
|
||||
assert least_relevant["index"] == 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize("documents", [
|
||||
[],
|
||||
None,
|
||||
|
||||
@@ -356,12 +356,12 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
|
||||
(None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
|
||||
(None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
|
||||
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
|
||||
("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
# ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
])
|
||||
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
global server
|
||||
@@ -401,7 +401,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "calculate",
|
||||
"content": 0.55644242476,
|
||||
"content": "0.55644242476",
|
||||
"tool_call_id": "call_6789"
|
||||
}
|
||||
],
|
||||
@@ -444,7 +444,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
|
||||
(128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
|
||||
(1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(1024, 'none', "<think>\n?I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(1024, 'none', "^I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
|
||||
(1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
])
|
||||
|
||||
@@ -7,14 +7,14 @@
|
||||
|
||||
// increase max payload length to allow use of larger context size
|
||||
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
|
||||
// disable Nagle's algorithm
|
||||
#define CPPHTTPLIB_TCP_NODELAY true
|
||||
#include "httplib.h"
|
||||
|
||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
#include "minja.hpp"
|
||||
#include "chat.hpp"
|
||||
#include "chat-template.hpp"
|
||||
#include "chat.h"
|
||||
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
@@ -347,41 +347,6 @@ static llama_tokens format_infill(
|
||||
return embd_inp;
|
||||
}
|
||||
|
||||
// Format given chat. If tmpl is empty, we take the template from model metadata
|
||||
inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
|
||||
std::vector<common_chat_msg> chat;
|
||||
|
||||
for (size_t i = 0; i < messages.size(); ++i) {
|
||||
const auto & curr_msg = messages[i];
|
||||
|
||||
std::string role = json_value(curr_msg, "role", std::string(""));
|
||||
|
||||
std::string content;
|
||||
if (curr_msg.contains("content")) {
|
||||
if (curr_msg["content"].is_string()) {
|
||||
content = curr_msg["content"].get<std::string>();
|
||||
} else if (curr_msg["content"].is_array()) {
|
||||
for (const auto & part : curr_msg["content"]) {
|
||||
if (part.contains("text")) {
|
||||
content += "\n" + part["text"].get<std::string>();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Missing 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
|
||||
}
|
||||
|
||||
chat.push_back({role, content, /* tool_calls= */ {}});
|
||||
}
|
||||
|
||||
const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
|
||||
LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
|
||||
|
||||
return formatted_chat;
|
||||
}
|
||||
|
||||
//
|
||||
// base64 utils (TODO: move to common in the future)
|
||||
//
|
||||
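The removed format_chat() above documents the two "content" shapes the OpenAI-compatible endpoint accepts and which the new common_chat_msgs_parse_oaicompat() path still has to handle: a plain string, or an array of parts whose "text" fields are concatenated. A minimal illustration of the two request messages (example data only, not code from the patch):

```cpp
// Sketch: the two accepted forms of a chat message "content" field.
#include <nlohmann/json.hpp>   // "json.hpp" inside the repo

using json = nlohmann::json;

int main() {
    // 1) content as a plain string
    json msg_string = {
        {"role", "user"},
        {"content", "What is the capital of France?"},
    };

    // 2) content as an array of text parts; the old code concatenated the
    //    parts' "text" fields and rejected anything else (see the issue link
    //    referenced in the removed code above)
    json msg_parts = {
        {"role", "user"},
        {"content", json::array({
            {{"type", "text"}, {"text", "What is"}},
            {{"type", "text"}, {"text", "the capital of France?"}},
        })},
    };

    (void) msg_string;
    (void) msg_parts;
    return 0;
}
```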
@@ -579,12 +544,9 @@ static json oaicompat_completion_params_parse(
|
||||
const json & body, /* openai api json semantics */
|
||||
bool use_jinja,
|
||||
common_reasoning_format reasoning_format,
|
||||
const common_chat_templates & chat_templates)
|
||||
const struct common_chat_templates * tmpls)
|
||||
{
|
||||
json llama_params;
|
||||
const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use
|
||||
? *chat_templates.template_tool_use
|
||||
: *chat_templates.template_default;
|
||||
|
||||
auto tools = json_value(body, "tools", json());
|
||||
auto stream = json_value(body, "stream", false);
|
||||
@@ -610,62 +572,58 @@ static json oaicompat_completion_params_parse(
|
||||
llama_params["stop"] = json_value(body, "stop", json::array());
|
||||
}
|
||||
|
||||
auto json_schema = json_value(body, "json_schema", json());
|
||||
auto grammar = json_value(body, "grammar", std::string());
|
||||
if (!json_schema.is_null() && !grammar.empty()) {
|
||||
throw std::runtime_error("Cannot use both json_schema and grammar");
|
||||
}
|
||||
|
||||
// Handle "response_format" field
|
||||
if (body.contains("response_format")) {
|
||||
json response_format = json_value(body, "response_format", json::object());
|
||||
std::string response_type = json_value(response_format, "type", std::string());
|
||||
if (response_type == "json_object") {
|
||||
llama_params["json_schema"] = json_value(response_format, "schema", json::object());
|
||||
json_schema = json_value(response_format, "schema", json::object());
|
||||
} else if (response_type == "json_schema") {
|
||||
json json_schema = json_value(response_format, "json_schema", json::object());
|
||||
llama_params["json_schema"] = json_value(json_schema, "schema", json::object());
|
||||
json_schema = json_value(json_schema, "schema", json::object());
|
||||
} else if (!response_type.empty() && response_type != "text") {
|
||||
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply chat template to the list of messages
|
||||
if (use_jinja) {
|
||||
auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
|
||||
if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
|
||||
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
|
||||
}
|
||||
if (tool_choice != "none" && llama_params.contains("grammar")) {
|
||||
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
|
||||
}
|
||||
common_chat_inputs inputs;
|
||||
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
inputs.messages = body.at("messages");
|
||||
inputs.tools = tools;
|
||||
inputs.tool_choice = tool_choice;
|
||||
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
|
||||
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
|
||||
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
|
||||
inputs.parallel_tool_calls = false;
|
||||
}
|
||||
inputs.stream = stream;
|
||||
// TODO: support mixing schema w/ tools beyond generic format.
|
||||
inputs.json_schema = json_value(llama_params, "json_schema", json());
|
||||
auto chat_params = common_chat_params_init(tmpl, inputs);
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages"));
|
||||
inputs.tools = common_chat_tools_parse_oaicompat(tools);
|
||||
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
|
||||
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
|
||||
inputs.grammar = grammar;
|
||||
inputs.add_generation_prompt = true;
|
||||
inputs.use_jinja = use_jinja;
|
||||
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
|
||||
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
|
||||
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
|
||||
}
|
||||
|
||||
llama_params["chat_format"] = static_cast<int>(chat_params.format);
|
||||
llama_params["prompt"] = chat_params.prompt;
|
||||
llama_params["grammar"] = chat_params.grammar;
|
||||
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
|
||||
auto grammar_triggers = json::array();
|
||||
for (const auto & trigger : chat_params.grammar_triggers) {
|
||||
grammar_triggers.push_back({
|
||||
{"word", trigger.word},
|
||||
{"at_start", trigger.at_start},
|
||||
});
|
||||
}
|
||||
llama_params["grammar_triggers"] = grammar_triggers;
|
||||
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
|
||||
for (const auto & stop : chat_params.additional_stops) {
|
||||
llama_params["stop"].push_back(stop);
|
||||
}
|
||||
} else {
|
||||
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
|
||||
// Apply chat template to the list of messages
|
||||
auto chat_params = common_chat_templates_apply(tmpls, inputs);
|
||||
|
||||
llama_params["chat_format"] = static_cast<int>(chat_params.format);
|
||||
llama_params["prompt"] = chat_params.prompt;
|
||||
llama_params["grammar"] = chat_params.grammar;
|
||||
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
|
||||
auto grammar_triggers = json::array();
|
||||
for (const auto & trigger : chat_params.grammar_triggers) {
|
||||
grammar_triggers.push_back({
|
||||
{"word", trigger.word},
|
||||
{"at_start", trigger.at_start},
|
||||
});
|
||||
}
|
||||
llama_params["grammar_triggers"] = grammar_triggers;
|
||||
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
|
||||
for (const auto & stop : chat_params.additional_stops) {
|
||||
llama_params["stop"].push_back(stop);
|
||||
}
|
||||
|
||||
// Handle "n" field
|
||||
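A rough sketch of the consolidated flow introduced in this hunk (assuming the chat.h API shown above; any field not visible in this diff is an assumption, and the fragment is only meaningful inside the repo): both the Jinja and legacy template cases now funnel through common_chat_templates_apply(), and the handler merely copies the returned prompt, grammar, and format into llama_params.

```cpp
// Sketch, not the actual handler: minimal use of the new chat-templates API.
#include "chat.h"      // common_chat_templates_*, as added in this change
#include "utils.hpp"   // json, json_value (server helpers)

static json apply_templates_sketch(const struct common_chat_templates * tmpls,
                                   const json & body, bool use_jinja) {
    common_chat_templates_inputs inputs;
    inputs.messages              = common_chat_msgs_parse_oaicompat(body.at("messages"));
    inputs.tools                 = common_chat_tools_parse_oaicompat(json_value(body, "tools", json()));
    inputs.tool_choice           = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
    inputs.use_jinja             = use_jinja;
    inputs.add_generation_prompt = true;

    const auto chat_params = common_chat_templates_apply(tmpls, inputs);

    json out;
    out["chat_format"] = static_cast<int>(chat_params.format);
    out["prompt"]      = chat_params.prompt;
    out["grammar"]     = chat_params.grammar;
    return out;
}
```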
@@ -737,29 +695,51 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
|
||||
return res;
|
||||
}
|
||||
|
||||
static json format_response_rerank(const json & request, const json & ranks) {
|
||||
json data = json::array();
|
||||
int32_t n_tokens = 0;
|
||||
int i = 0;
|
||||
for (const auto & rank : ranks) {
|
||||
data.push_back(json{
|
||||
{"index", i++},
|
||||
{"relevance_score", json_value(rank, "score", 0.0)},
|
||||
});
|
||||
static json format_response_rerank(
|
||||
const json & request,
|
||||
const json & ranks,
|
||||
bool is_tei_format,
|
||||
std::vector<std::string> & texts) {
|
||||
json res;
|
||||
if (is_tei_format) {
|
||||
// TEI response format
|
||||
res = json::array();
|
||||
bool return_text = json_value(request, "return_text", false);
|
||||
for (const auto & rank : ranks) {
|
||||
int index = json_value(rank, "index", 0);
|
||||
json elem = json{
|
||||
{"index", index},
|
||||
{"score", json_value(rank, "score", 0.0)},
|
||||
};
|
||||
if (return_text) {
|
||||
elem["text"] = std::move(texts[index]);
|
||||
}
|
||||
res.push_back(elem);
|
||||
}
|
||||
} else {
|
||||
// Jina response format
|
||||
json results = json::array();
|
||||
int32_t n_tokens = 0;
|
||||
for (const auto & rank : ranks) {
|
||||
results.push_back(json{
|
||||
{"index", json_value(rank, "index", 0)},
|
||||
{"relevance_score", json_value(rank, "score", 0.0)},
|
||||
});
|
||||
|
||||
n_tokens += json_value(rank, "tokens_evaluated", 0);
|
||||
n_tokens += json_value(rank, "tokens_evaluated", 0);
|
||||
}
|
||||
|
||||
res = json{
|
||||
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||
{"object", "list"},
|
||||
{"usage", json{
|
||||
{"prompt_tokens", n_tokens},
|
||||
{"total_tokens", n_tokens}
|
||||
}},
|
||||
{"results", results}
|
||||
};
|
||||
}
|
||||
|
||||
json res = json {
|
||||
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||
{"object", "list"},
|
||||
{"usage", json {
|
||||
{"prompt_tokens", n_tokens},
|
||||
{"total_tokens", n_tokens}
|
||||
}},
|
||||
{"results", data}
|
||||
};
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
@@ -159,6 +159,35 @@ export default function ChatMessage({
|
||||
</div>
|
||||
</details>
|
||||
)}
|
||||
|
||||
{msg.extra && msg.extra.length > 0 && (
|
||||
<details
|
||||
className={classNames({
|
||||
'collapse collapse-arrow mb-4 bg-base-200': true,
|
||||
'bg-opacity-10': msg.role !== 'assistant',
|
||||
})}
|
||||
>
|
||||
<summary className="collapse-title">
|
||||
Extra content
|
||||
</summary>
|
||||
<div className="collapse-content">
|
||||
{msg.extra.map(
|
||||
(extra, i) =>
|
||||
extra.type === 'textFile' ? (
|
||||
<div key={extra.name}>
|
||||
<b>{extra.name}</b>
|
||||
<pre>{extra.content}</pre>
|
||||
</div>
|
||||
) : extra.type === 'context' ? (
|
||||
<div key={i}>
|
||||
<pre>{extra.content}</pre>
|
||||
</div>
|
||||
) : null // TODO: support other extra types
|
||||
)}
|
||||
</div>
|
||||
</details>
|
||||
)}
|
||||
|
||||
<MarkdownDisplay
|
||||
content={content}
|
||||
isGenerating={isPending}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { useEffect, useMemo, useState } from 'react';
|
||||
import { useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
|
||||
import ChatMessage from './ChatMessage';
|
||||
import { CanvasType, Message, PendingMessage } from '../utils/types';
|
||||
import { classNames, throttle } from '../utils/misc';
|
||||
import CanvasPyInterpreter from './CanvasPyInterpreter';
|
||||
import StorageUtils from '../utils/storage';
|
||||
import { useVSCodeContext } from '../utils/llama-vscode';
|
||||
|
||||
/**
|
||||
* A message display is a message node with additional information for rendering.
|
||||
@@ -81,6 +82,14 @@ export default function ChatScreen() {
|
||||
replaceMessageAndGenerate,
|
||||
} = useAppContext();
|
||||
const [inputMsg, setInputMsg] = useState('');
|
||||
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
const { extraContext, clearExtraContext } = useVSCodeContext(
|
||||
inputRef,
|
||||
setInputMsg
|
||||
);
|
||||
// TODO: improve this when we have "upload file" feature
|
||||
const currExtra: Message['extra'] = extraContext ? [extraContext] : undefined;
|
||||
|
||||
// keep track of leaf node for rendering
|
||||
const [currNodeId, setCurrNodeId] = useState<number>(-1);
|
||||
@@ -115,10 +124,20 @@ export default function ChatScreen() {
|
||||
setCurrNodeId(-1);
|
||||
// get the last message node
|
||||
const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
|
||||
if (!(await sendMessage(currConvId, lastMsgNodeId, inputMsg, onChunk))) {
|
||||
if (
|
||||
!(await sendMessage(
|
||||
currConvId,
|
||||
lastMsgNodeId,
|
||||
inputMsg,
|
||||
currExtra,
|
||||
onChunk
|
||||
))
|
||||
) {
|
||||
// restore the input message if failed
|
||||
setInputMsg(lastInpMsg);
|
||||
}
|
||||
// OK
|
||||
clearExtraContext();
|
||||
};
|
||||
|
||||
const handleEditMessage = async (msg: Message, content: string) => {
|
||||
@@ -129,6 +148,7 @@ export default function ChatScreen() {
|
||||
viewingChat.conv.id,
|
||||
msg.parent,
|
||||
content,
|
||||
msg.extra,
|
||||
onChunk
|
||||
);
|
||||
setCurrNodeId(-1);
|
||||
@@ -143,6 +163,7 @@ export default function ChatScreen() {
|
||||
viewingChat.conv.id,
|
||||
msg.parent,
|
||||
null,
|
||||
msg.extra,
|
||||
onChunk
|
||||
);
|
||||
setCurrNodeId(-1);
|
||||
@@ -203,9 +224,11 @@ export default function ChatScreen() {
|
||||
<textarea
|
||||
className="textarea textarea-bordered w-full"
|
||||
placeholder="Type a message (Shift+Enter to add a new line)"
|
||||
ref={inputRef}
|
||||
value={inputMsg}
|
||||
onChange={(e) => setInputMsg(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.nativeEvent.isComposing || e.keyCode === 229) return;
|
||||
if (e.key === 'Enter' && e.shiftKey) return;
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
|
||||
@@ -25,6 +25,7 @@ interface AppContextValue {
|
||||
convId: string | null,
|
||||
leafNodeId: Message['id'] | null,
|
||||
content: string,
|
||||
extra: Message['extra'],
|
||||
onChunk: CallbackGeneratedChunk
|
||||
) => Promise<boolean>;
|
||||
stopGenerating: (convId: string) => void;
|
||||
@@ -32,6 +33,7 @@ interface AppContextValue {
|
||||
convId: string,
|
||||
parentNodeId: Message['id'], // the parent node of the message to be replaced
|
||||
content: string | null,
|
||||
extra: Message['extra'],
|
||||
onChunk: CallbackGeneratedChunk
|
||||
) => Promise<void>;
|
||||
|
||||
@@ -274,6 +276,7 @@ export const AppContextProvider = ({
|
||||
convId: string | null,
|
||||
leafNodeId: Message['id'] | null,
|
||||
content: string,
|
||||
extra: Message['extra'],
|
||||
onChunk: CallbackGeneratedChunk
|
||||
): Promise<boolean> => {
|
||||
if (isGenerating(convId ?? '') || content.trim().length === 0) return false;
|
||||
@@ -298,6 +301,7 @@ export const AppContextProvider = ({
|
||||
convId,
|
||||
role: 'user',
|
||||
content,
|
||||
extra,
|
||||
parent: leafNodeId,
|
||||
children: [],
|
||||
},
|
||||
@@ -324,6 +328,7 @@ export const AppContextProvider = ({
|
||||
convId: string,
|
||||
parentNodeId: Message['id'], // the parent node of the message to be replaced
|
||||
content: string | null,
|
||||
extra: Message['extra'],
|
||||
onChunk: CallbackGeneratedChunk
|
||||
) => {
|
||||
if (isGenerating(convId)) return;
|
||||
@@ -339,6 +344,7 @@ export const AppContextProvider = ({
|
||||
convId,
|
||||
role: 'user',
|
||||
content,
|
||||
extra,
|
||||
parent: parentNodeId,
|
||||
children: [],
|
||||
},
|
||||
|
||||
62
examples/server/webui/src/utils/llama-vscode.ts
Normal file
@@ -0,0 +1,62 @@
|
||||
import { useEffect, useState } from 'react';
|
||||
import { MessageExtraContext } from './types';
|
||||
|
||||
// Extra context when using llama.cpp WebUI from llama-vscode, inside an iframe
|
||||
// Ref: https://github.com/ggml-org/llama.cpp/pull/11940
|
||||
|
||||
interface SetTextEvData {
|
||||
text: string;
|
||||
context: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* To test it:
|
||||
* window.postMessage({ command: 'setText', text: 'Spot the syntax error', context: 'def test()\n return 123' }, '*');
|
||||
*/
|
||||
|
||||
export const useVSCodeContext = (
|
||||
inputRef: React.RefObject<HTMLTextAreaElement>,
|
||||
setInputMsg: (text: string) => void
|
||||
) => {
|
||||
const [extraContext, setExtraContext] = useState<MessageExtraContext | null>(
|
||||
null
|
||||
);
|
||||
|
||||
// Accept setText message from a parent window and set inputMsg and extraContext
|
||||
useEffect(() => {
|
||||
const handleMessage = (event: MessageEvent) => {
|
||||
if (event.data?.command === 'setText') {
|
||||
const data: SetTextEvData = event.data;
|
||||
setInputMsg(data?.text);
|
||||
if (data?.context && data.context.length > 0) {
|
||||
setExtraContext({
|
||||
type: 'context',
|
||||
content: data.context,
|
||||
});
|
||||
}
|
||||
inputRef.current?.focus();
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('message', handleMessage);
|
||||
return () => window.removeEventListener('message', handleMessage);
|
||||
}, [inputRef, setInputMsg]);
|
||||
|
||||
// Add a keydown listener that sends the "escapePressed" message to the parent window
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (event.key === 'Escape') {
|
||||
window.parent.postMessage({ command: 'escapePressed' }, '*');
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('keydown', handleKeyDown);
|
||||
return () => window.removeEventListener('keydown', handleKeyDown);
|
||||
}, []);
|
||||
|
||||
return {
|
||||
extraContext,
|
||||
// call once the user message is sent, to clear the extra context
|
||||
clearExtraContext: () => setExtraContext(null),
|
||||
};
|
||||
};
|
||||
@@ -53,12 +53,23 @@ export const copyStr = (textToCopy: string) => {
|
||||
|
||||
/**
|
||||
* filter out redundant fields upon sending to API
|
||||
* also format extra into text
|
||||
*/
|
||||
export function normalizeMsgsForAPI(messages: Readonly<Message[]>) {
|
||||
return messages.map((msg) => {
|
||||
let newContent = '';
|
||||
|
||||
for (const extra of msg.extra ?? []) {
|
||||
if (extra.type === 'context') {
|
||||
newContent += `${extra.content}\n\n`;
|
||||
}
|
||||
}
|
||||
|
||||
newContent += msg.content;
|
||||
|
||||
return {
|
||||
role: msg.role,
|
||||
content: msg.content,
|
||||
content: newContent,
|
||||
};
|
||||
}) as APIMessage[];
|
||||
}
|
||||
|
||||
@@ -42,11 +42,25 @@ export interface Message {
|
||||
role: 'user' | 'assistant' | 'system';
|
||||
content: string;
|
||||
timings?: TimingReport;
|
||||
extra?: MessageExtra[];
|
||||
// node based system for branching
|
||||
parent: Message['id'];
|
||||
children: Message['id'][];
|
||||
}
|
||||
|
||||
type MessageExtra = MessageExtraTextFile | MessageExtraContext; // TODO: will add more in the future
|
||||
|
||||
export interface MessageExtraTextFile {
|
||||
type: 'textFile';
|
||||
name: string;
|
||||
content: string;
|
||||
}
|
||||
|
||||
export interface MessageExtraContext {
|
||||
type: 'context';
|
||||
content: string;
|
||||
}
|
||||
|
||||
export type APIMessage = Pick<Message, 'role' | 'content'>;
|
||||
|
||||
export interface Conversation {
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# MIT license
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
#export GGML_SYCL_DEBUG=1
|
||||
@@ -13,7 +13,7 @@ source /opt/intel/oneapi/setvars.sh
|
||||
INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
|
||||
MODEL_FILE=models/llama-2-7b.Q4_0.gguf
|
||||
NGL=33
|
||||
CONEXT=8192
|
||||
CONEXT=4096
|
||||
|
||||
if [ $# -gt 0 ]; then
|
||||
GGML_SYCL_DEVICE=$1
|
||||
|
||||
@@ -102,6 +102,7 @@ endif()
|
||||
|
||||
option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
|
||||
option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
|
||||
option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF)
|
||||
option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
|
||||
option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF)
|
||||
option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
|
||||
@@ -121,6 +122,7 @@ endif()
|
||||
option(GGML_LASX "ggml: enable lasx" ON)
|
||||
option(GGML_LSX "ggml: enable lsx" ON)
|
||||
option(GGML_RVV "ggml: enable rvv" ON)
|
||||
option(GGML_VXE "ggml: enable vxe" ON)
|
||||
|
||||
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
|
||||
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
||||
@@ -150,6 +152,7 @@ set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
|
||||
"ggml: max. batch size for using peer access")
|
||||
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
|
||||
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
|
||||
option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON)
|
||||
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
|
||||
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
|
||||
|
||||
|
||||
@@ -95,9 +95,11 @@ extern "C" {
|
||||
GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_sve (void);
|
||||
GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes
|
||||
GGML_BACKEND_API int ggml_cpu_has_sme (void);
|
||||
// other
|
||||
GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_vsx (void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_vxe (void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_llamafile (void);
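A small usage sketch for the new runtime query (assumes linking against the ggml CPU backend; not part of the patch): callers can probe SME the same way as the existing flags before choosing a kernel path.

```cpp
// Sketch: query the new SME flag next to the existing runtime checks.
#include <stdio.h>
#include "ggml-cpu.h"

int main(void) {
    printf("SVE : %d (vector length: %d bytes)\n", ggml_cpu_has_sve(), ggml_cpu_get_sve_cnt());
    printf("SME : %d\n", ggml_cpu_has_sme());
    printf("VXE : %d\n", ggml_cpu_has_vxe());
    return 0;
}
```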
|
||||
|
||||
|
||||
@@ -111,14 +111,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
function(check_arm_feature tag code)
|
||||
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
||||
set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
|
||||
check_cxx_source_runs(
|
||||
"${code}"
|
||||
GGML_MACHINE_SUPPORTS_${tag}
|
||||
)
|
||||
check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
|
||||
if (GGML_MACHINE_SUPPORTS_${tag})
|
||||
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
|
||||
else()
|
||||
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
|
||||
set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}")
|
||||
check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
|
||||
if (GGML_MACHINE_SUPPORTS_no${tag})
|
||||
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
|
||||
endif()
|
||||
endif()
|
||||
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
|
||||
endfunction()
|
||||
@@ -126,6 +127,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
|
||||
check_arm_feature(i8mm "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
|
||||
check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
|
||||
check_arm_feature(sme "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
|
||||
|
||||
list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}")
|
||||
else()
|
||||
@@ -150,7 +152,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
if (ARM_FEATURE_RESULT)
|
||||
message(WARNING "Failed to get ARM features")
|
||||
else()
|
||||
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC)
|
||||
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
|
||||
string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
|
||||
if (NOT ${feature_pos} EQUAL -1)
|
||||
message(STATUS "ARM feature ${feature} enabled")
|
||||
@@ -308,6 +310,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
if (GGML_RVV)
|
||||
list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
|
||||
endif()
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
|
||||
message(STATUS "s390x detected")
|
||||
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
|
||||
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
|
||||
|
||||
# TODO: Separation to determine activation of VX/VXE/VXE2
|
||||
if (${S390X_M} MATCHES "8561|8562")
|
||||
message(STATUS "z15 target")
|
||||
list(APPEND ARCH_FLAGS -march=z15 -mtune=z15)
|
||||
elseif (${S390X_M} MATCHES "3931")
|
||||
message(STATUS "z16 target")
|
||||
list(APPEND ARCH_FLAGS -march=z16 -mtune=z16)
|
||||
else()
|
||||
message(STATUS "Unknown target")
|
||||
message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
|
||||
list(APPEND ARCH_FLAGS -march=native -mtune=native)
|
||||
endif()
|
||||
|
||||
if (GGML_VXE)
|
||||
list(APPEND ARCH_FLAGS -mvx -mzvector)
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "Unknown architecture")
|
||||
endif()
|
||||
@@ -316,6 +339,94 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64)
|
||||
endif()
|
||||
|
||||
if (GGML_CPU_KLEIDIAI)
|
||||
message(STATUS "Using KleidiAI optimized kernels if applicable")
|
||||
|
||||
# Disable the KleidiAI tests
|
||||
set(KLEIDIAI_BUILD_TESTS OFF)
|
||||
|
||||
# Fetch KleidiAI sources:
|
||||
include(FetchContent)
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.3.0")
|
||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "060bd2dc64642b091f461cc8dd7426d9")
|
||||
|
||||
if (POLICY CMP0135)
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
endif()
|
||||
|
||||
FetchContent_Declare(KleidiAI_Download
|
||||
URL ${KLEIDIAI_DOWNLOAD_URL}
|
||||
DOWNLOAD_EXTRACT_TIMESTAMP NEW
|
||||
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
|
||||
|
||||
FetchContent_MakeAvailable(KleidiAI_Download)
|
||||
FetchContent_GetProperties(KleidiAI_Download
|
||||
SOURCE_DIR KLEIDIAI_SRC
|
||||
POPULATED KLEIDIAI_POPULATED)
|
||||
|
||||
if (NOT KLEIDIAI_POPULATED)
|
||||
message(FATAL_ERROR "KleidiAI source downloaded failed.")
|
||||
endif()
|
||||
|
||||
add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
|
||||
|
||||
# Remove kleidiai target after fetching it
|
||||
if (TARGET kleidiai)
|
||||
set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
|
||||
endif()
|
||||
|
||||
list(APPEND GGML_CPU_SOURCES
|
||||
ggml-cpu/kleidiai/kleidiai.cpp
|
||||
ggml-cpu/kleidiai/kernels.cpp
|
||||
ggml-cpu/kleidiai/kleidiai.h
|
||||
ggml-cpu/kleidiai/kernels.h
|
||||
)
|
||||
|
||||
# KleidiAI
|
||||
include_directories(
|
||||
${KLEIDIAI_SRC}/
|
||||
${KLEIDIAI_SRC}/kai/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
|
||||
|
||||
set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
|
||||
if (NOT ARCH_FLAGS_TEMP)
|
||||
string(REGEX MATCH "-march=[^ ]+" ARCH_FLAGS_TEMP "${CMAKE_C_FLAGS}")
|
||||
endif()
|
||||
string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED)
|
||||
string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
|
||||
string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
|
||||
|
||||
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
|
||||
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
|
||||
|
||||
if (NOT DOTPROD_ENABLED MATCHES -1)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
|
||||
endif()
|
||||
|
||||
if (NOT I8MM_ENABLED MATCHES -1)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
|
||||
endif()
|
||||
|
||||
if (NOT SME_ENABLED MATCHES -1)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
|
||||
set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
|
||||
endif()
|
||||
|
||||
set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
|
||||
list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES})
|
||||
endif()
|
||||
|
||||
message(STATUS "Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}")
|
||||
target_sources(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_SOURCES})
|
||||
target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS})
|
||||
|
||||
@@ -59,6 +59,15 @@ struct ggml_compute_params {
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(__s390x__) && defined(__VEC__)
|
||||
#ifndef __VXE__
|
||||
#define __VXE__
|
||||
#endif
|
||||
#ifndef __VXE2__
|
||||
#define __VXE2__
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
#include <arm_sve.h>
|
||||
#include <sys/prctl.h>
|
||||
@@ -359,6 +368,148 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
#include <vecintrin.h>
|
||||
|
||||
#define vec_neg(a) (-(a)) // Vector Negate
|
||||
#define vec_add(a, b) ((a) + (b)) // Vector Add
|
||||
#define vec_sub(a, b) ((a) - (b)) // Vector Subtract
|
||||
#define vec_mul(a, b) ((a) * (b)) // Vector Multiply
|
||||
#define vec_div(a, b) ((a) / (b)) // Vector Divide
|
||||
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
|
||||
#define vec_sra(a, b) ((a) >> (b)) // Vector Shift Right
|
||||
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic
|
||||
#define vec_slo(a, b) vec_slb(a, (b) << 64) // Vector Shift Left by Octet
|
||||
#define vec_sro(a, b) vec_srb(a, (b) << 64) // Vector Shift Right by Octet
|
||||
|
||||
#ifndef vec_and
|
||||
#define vec_and(a, b) ((a) & (b)) // Vector AND
|
||||
#endif
|
||||
|
||||
#ifndef vec_or
|
||||
#define vec_or(a, b) ((a) | (b)) // Vector OR
|
||||
#endif
|
||||
|
||||
#ifndef vec_xor
|
||||
#define vec_xor(a, b) ((a) ^ (b)) // Vector XOR
|
||||
#endif
|
||||
|
||||
typedef signed char char8x16_t __attribute__((vector_size(16)));
|
||||
typedef unsigned char uchar8x16_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef int8_t int8x16_t __attribute__((vector_size(16)));
|
||||
typedef int16_t int16x8_t __attribute__((vector_size(16)));
|
||||
typedef int32_t int32x4_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef uint8_t uint8x16_t __attribute__((vector_size(16)));
|
||||
typedef uint16_t uint16x8_t __attribute__((vector_size(16)));
|
||||
typedef uint32_t uint32x4_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef float float32x4_t __attribute__((vector_size(16)));
|
||||
typedef double double64x2_t __attribute((vector_size(16)));
|
||||
|
||||
typedef signed long long long64x2_t __attribute((vector_size(16)));
|
||||
typedef unsigned long long ulong64x2_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef struct ggml_uint8x16x2_t {
|
||||
uint8x16_t val[2];
|
||||
} ggml_uint8x16x2_t;
|
||||
|
||||
inline static ggml_uint8x16x2_t ggml_vec_xl_u8x2(const uint8_t * ptr) {
|
||||
ggml_uint8x16x2_t res;
|
||||
|
||||
res.val[0] = vec_xl( 0, ptr);
|
||||
res.val[1] = vec_xl(16, ptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
typedef struct ggml_uint8x16x4_t {
|
||||
uint8x16_t val[4];
|
||||
} ggml_uint8x16x4_t;
|
||||
|
||||
inline static ggml_uint8x16x4_t ggml_vec_xl_u8x4(const uint8_t * ptr) {
|
||||
ggml_uint8x16x4_t res;
|
||||
|
||||
res.val[0] = vec_xl( 0, ptr);
|
||||
res.val[1] = vec_xl(16, ptr);
|
||||
res.val[2] = vec_xl(32, ptr);
|
||||
res.val[3] = vec_xl(48, ptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
typedef struct ggml_int8x16x4_t {
|
||||
int8x16_t val[4];
|
||||
} ggml_int8x16x4_t;
|
||||
|
||||
inline static ggml_int8x16x4_t ggml_vec_xl_s8x4(const int8_t * ptr) {
|
||||
ggml_int8x16x4_t res;
|
||||
|
||||
res.val[0] = vec_xl( 0, ptr);
|
||||
res.val[1] = vec_xl(16, ptr);
|
||||
res.val[2] = vec_xl(32, ptr);
|
||||
res.val[3] = vec_xl(48, ptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
typedef struct ggml_int16x8x2_t {
|
||||
int16x8_t val[2];
|
||||
} ggml_int16x8x2_t;
|
||||
|
||||
inline static ggml_int16x8x2_t ggml_vec_xl_s16x2(const int16_t * ptr) {
|
||||
ggml_int16x8x2_t res;
|
||||
|
||||
res.val[0] = vec_xl( 0, ptr);
|
||||
res.val[1] = vec_xl(16, ptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/*
|
||||
! WARNING: Very slow. Use vec_perm if possible. Refer to iq4_xs
|
||||
! or iq4_nl for example implementation.
|
||||
*/
|
||||
inline static int8x16_t ggml_vec_tbl(int8x16_t a, uint8x16_t b) {
|
||||
int8x16_t res;
|
||||
|
||||
res[ 0] = a[b[ 0]];
|
||||
res[ 1] = a[b[ 1]];
|
||||
res[ 2] = a[b[ 2]];
|
||||
res[ 3] = a[b[ 3]];
|
||||
res[ 4] = a[b[ 4]];
|
||||
res[ 5] = a[b[ 5]];
|
||||
res[ 6] = a[b[ 6]];
|
||||
res[ 7] = a[b[ 7]];
|
||||
res[ 8] = a[b[ 8]];
|
||||
res[ 9] = a[b[ 9]];
|
||||
res[10] = a[b[10]];
|
||||
res[11] = a[b[11]];
|
||||
res[12] = a[b[12]];
|
||||
res[13] = a[b[13]];
|
||||
res[14] = a[b[14]];
|
||||
res[15] = a[b[15]];
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
|
||||
const uchar8x16_t v_maske = { 0, 1, 4, 5, 8, 9, 12, 13,
|
||||
16, 17, 20, 21, 24, 25, 28, 29 };
|
||||
|
||||
const int16x8_t v_abo = vec_pack((int32x4_t)a, (int32x4_t)b);
|
||||
const int16x8_t v_abe = vec_perm(a, b, v_maske);
|
||||
return v_abo + v_abe;
|
||||
}
|
||||
|
||||
inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
|
||||
const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);
|
||||
return acc + (vec_unpackh(p) + vec_unpackl(p));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(__loongarch_asx)
|
||||
/* float type data load instructions */
|
||||
static __m128 __lsx_vreplfr2vr_s(const float val) {
|
||||
|
||||
@@ -1011,6 +1011,38 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
|
||||
__lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
|
||||
|
||||
}
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
__vector float srcv [8];
|
||||
__vector float asrcv[8];
|
||||
__vector float amaxv[8];
|
||||
|
||||
for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
|
||||
for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
|
||||
for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
|
||||
for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
|
||||
for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
|
||||
|
||||
const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
|
||||
vec_extract(amaxv[0], 1)),
|
||||
MAX(vec_extract(amaxv[0], 2),
|
||||
vec_extract(amaxv[0], 3)));
|
||||
|
||||
const float d = amax / ((1 << 7) - 1);
|
||||
const float id = d ? 1.0f / d : 0.0f;
|
||||
|
||||
y[i].d = GGML_FP32_TO_FP16(d);
|
||||
|
||||
for (int j = 0; j < 8; j++) {
|
||||
const __vector float v = vec_mul(srcv[j], vec_splats(id));
|
||||
const __vector int32_t vi = vec_signed(v);
|
||||
|
||||
y[i].qs[4*j + 0] = vec_extract(vi, 0);
|
||||
y[i].qs[4*j + 1] = vec_extract(vi, 1);
|
||||
y[i].qs[4*j + 2] = vec_extract(vi, 2);
|
||||
y[i].qs[4*j + 3] = vec_extract(vi, 3);
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(nb);
|
||||
// scalar
|
||||
@@ -1337,6 +1369,44 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
|
||||
__lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0);
|
||||
__lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
|
||||
}
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
__vector float srcv [8];
|
||||
__vector float asrcv[8];
|
||||
__vector float amaxv[8];
|
||||
|
||||
for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
|
||||
for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
|
||||
for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
|
||||
for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
|
||||
for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
|
||||
|
||||
const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
|
||||
vec_extract(amaxv[0], 1)),
|
||||
MAX(vec_extract(amaxv[0], 2),
|
||||
vec_extract(amaxv[0], 3)));
|
||||
|
||||
const float d = amax / ((1 << 7) - 1);
|
||||
const float id = d ? 1.0f / d : 0.0f;
|
||||
|
||||
y[i].d = GGML_FP32_TO_FP16(d);
|
||||
|
||||
__vector int32_t acc = vec_splats(0);
|
||||
|
||||
for (int j = 0; j < 8; j++) {
|
||||
const __vector float v = vec_mul(srcv[j], vec_splats(id));
|
||||
const __vector int32_t vi = vec_signed(v);
|
||||
|
||||
y[i].qs[4*j + 0] = vec_extract(vi, 0);
|
||||
y[i].qs[4*j + 1] = vec_extract(vi, 1);
|
||||
y[i].qs[4*j + 2] = vec_extract(vi, 2);
|
||||
y[i].qs[4*j + 3] = vec_extract(vi, 3);
|
||||
|
||||
acc = vec_add(acc, vi);
|
||||
}
|
||||
|
||||
y[i].s = GGML_FP32_TO_FP16(d * (acc[0] + acc[1] + acc[2] + acc[3]));
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(nb);
|
||||
// scalar
|
||||
@@ -2488,6 +2558,37 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
||||
}
|
||||
|
||||
sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
__vector float acc = vec_splats(0.0f);
|
||||
|
||||
const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F);
|
||||
const __vector int8_t v_s = vec_splats( (const int8_t)0x08);
|
||||
|
||||
for (; ib < nb; ++ib) {
|
||||
const __vector uint8_t v_x = vec_xl(0, x[ib].qs);
|
||||
const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m);
|
||||
const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4);
|
||||
|
||||
const __vector int8_t v_xls = vec_sub(v_xl, v_s);
|
||||
const __vector int8_t v_xhs = vec_sub(v_xh, v_s);
|
||||
|
||||
const __vector int8_t v_yl = vec_xl(0 , y[ib].qs);
|
||||
const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
|
||||
|
||||
const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl);
|
||||
const __vector int16_t v_xylse = vec_mule(v_xls, v_yl);
|
||||
const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh);
|
||||
const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh);
|
||||
|
||||
__vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_);
|
||||
|
||||
const __vector float v_xy = vec_float(vec_unpackh(v_xy_));
|
||||
const __vector float v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
|
||||
|
||||
acc = vec_madd(v_xy, v_d, acc);
|
||||
}
|
||||
|
||||
sumf = acc[0] + acc[1] + acc[2] + acc[3];
|
||||
#endif
|
||||
for (; ib < nb; ++ib) {
|
||||
int sumi0 = 0;
|
||||
@@ -2781,6 +2882,35 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
||||
}
|
||||
|
||||
sumf = hsum_float_8(acc) + summs;
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
float summs = 0;
|
||||
float32x4_t acc = vec_splats(0.0f);
|
||||
|
||||
const uint8x16_t v_m = vec_splat_u8(0x0F);
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (; ib < nb; ++ib) {
|
||||
__builtin_prefetch(x[ib].qs, 0, 1);
|
||||
__builtin_prefetch(y[ib].qs, 0, 1);
|
||||
|
||||
summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
|
||||
|
||||
const uint8x16_t v_x = vec_xl(0, x[ib].qs);
|
||||
const int8x16_t v_xl = (const int8x16_t)(v_x & v_m);
|
||||
const int8x16_t v_xh = (const int8x16_t)(v_x >> 4);
|
||||
|
||||
const int8x16_t v_yl = vec_xl(0 , y[ib].qs);
|
||||
const int8x16_t v_yh = vec_xl(QK8_1/2, y[ib].qs);
|
||||
|
||||
const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
|
||||
const float32x4_t v_xy = vec_float(v_xy_);
|
||||
|
||||
const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
|
||||
|
||||
acc = vec_madd(v_xy, v_d, acc);
|
||||
}
|
||||
|
||||
sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs;
|
||||
#endif
|
||||
for (; ib < nb; ++ib) {
|
||||
int sumi0 = 0;
|
||||
@@ -3915,6 +4045,27 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
||||
}
|
||||
|
||||
sumf = hsum_float_8(acc);
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
__vector float acc = vec_splats(0.0f);
|
||||
|
||||
#pragma GCC unroll 8
|
||||
for (; ib < nb; ++ib) {
|
||||
__builtin_prefetch(x[ib].qs, 0, 1);
|
||||
__builtin_prefetch(y[ib].qs, 0, 1);
|
||||
|
||||
const int8x16_t v_xl = vec_xl(0 , x[ib].qs);
|
||||
const int8x16_t v_xh = vec_xl(QK8_0/2, x[ib].qs);
|
||||
const int8x16_t v_yl = vec_xl(0 , y[ib].qs);
|
||||
const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
|
||||
|
||||
const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
|
||||
const float32x4_t v_xy = vec_float(v_xy_);
|
||||
const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
|
||||
|
||||
acc = vec_madd(v_xy, v_d, acc);
|
||||
}
|
||||
|
||||
sumf = acc[0] + acc[1] + acc[2] + acc[3];
|
||||
#endif
|
||||
for (; ib < nb; ++ib) {
|
||||
int sumi = 0;
|
||||
@@ -5112,7 +5263,182 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
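The hunk below adds an SVE path for the Q3_K dot product. Because the SVE register width is only known at run time, the new code queries it and switches between a 128-bit specialization and a shared 256/512-bit one. A bare sketch of that dispatch pattern (not the actual kernel; requires an SVE-enabled target to compile):

```cpp
// Sketch of the runtime vector-length dispatch used by the new SVE path.
#include <arm_sve.h>

void dot_q3_K_dispatch_sketch(void) {
    const int vector_length = svcntb() * 8;   // SVE register width in bits
    switch (vector_length) {
        case 128:
            // process 16 bytes of quants per SVE op (full-width predicates)
            break;
        case 256:
        case 512:
            // process 32 bytes per op; svptrue_pat_b8(SV_VL32) keeps the
            // 512-bit case limited to the first 32 lanes
            break;
        default:
            // unsupported vector length
            break;
    }
}
```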
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
#ifdef __ARM_NEON
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
|
||||
uint32_t utmp[4];
|
||||
|
||||
const int8_t m32 = 32;
|
||||
const int vector_length = svcntb()*8;
|
||||
const svuint8_t m3b_sv = svdup_n_u8(0x3);
|
||||
const svint32_t vzero_sv = svdup_n_s32(0);
|
||||
|
||||
const svuint8_t m0_sv = svdup_n_u8(1);
|
||||
const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1);
|
||||
const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2);
|
||||
const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3);
|
||||
svbool_t pred_s32 = svnot_b_z (svptrue_b32(), svptrue_pat_b32(SV_VL4));
|
||||
|
||||
float sum = 0;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
||||
|
||||
const uint8_t * restrict q3_sv = x[i].qs;
|
||||
const uint8_t * restrict qh_sv = x[i].hmask;
|
||||
const int8_t * restrict q8_sv = y[i].qs;
|
||||
|
||||
// Set up scales
|
||||
uint32_t * aux = &x[i].scales;
|
||||
utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
|
||||
utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
|
||||
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
|
||||
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
|
||||
|
||||
int8_t * scale = (int8_t *)utmp;
|
||||
|
||||
for (int j = 0; j < 16; ++j) scale[j] -= m32;
|
||||
|
||||
switch (vector_length) {
|
||||
case 128:
|
||||
{
|
||||
svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv);
|
||||
svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16);
|
||||
svuint8_t q3h_sv;
|
||||
|
||||
svint32_t sumi1_1 = svdup_n_s32(0);
|
||||
svint8_t q3bytes_sv;
|
||||
|
||||
for (int j = 0; j < QK_K/128; ++j) {
|
||||
|
||||
const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
|
||||
const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
|
||||
svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||
svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||
|
||||
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2);
|
||||
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||
|
||||
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
|
||||
|
||||
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2);
|
||||
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||
|
||||
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
|
||||
|
||||
q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||
|
||||
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1);
|
||||
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||
|
||||
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
|
||||
|
||||
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1);
|
||||
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||
|
||||
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
|
||||
|
||||
|
||||
scale += 4;
|
||||
q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||
|
||||
q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1);
|
||||
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||
|
||||
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
|
||||
|
||||
q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2);
|
||||
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||
|
||||
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
|
||||
|
||||
|
||||
q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||
|
||||
q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1);
|
||||
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||
|
||||
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
|
||||
|
||||
q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1);
|
||||
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||
|
||||
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
|
||||
|
||||
if (j == 0) {
|
||||
qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4);
|
||||
qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4);
|
||||
}
|
||||
|
||||
scale += 4;
|
||||
}
|
||||
|
||||
sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));
|
||||
} break;
|
||||
case 256:
|
||||
case 512:
|
||||
{
|
||||
svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv);
|
||||
svuint8_t q3h_sv;
|
||||
|
||||
svint32_t sumi1_1 = svdup_n_s32(0);
|
||||
svint8_t q3bytes_sv;
|
||||
|
||||
for (int j = 0; j < QK_K/128; ++j) {
|
||||
|
||||
const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32;
|
||||
svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
||||
svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
||||
|
||||
q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2);
|
||||
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||
|
||||
|
||||
svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
|
||||
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
|
||||
|
||||
q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1);
|
||||
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||
|
||||
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
|
||||
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
|
||||
|
||||
scale += 4;
|
||||
q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
||||
q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
||||
|
||||
q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv);
|
||||
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||
|
||||
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
|
||||
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
|
||||
|
||||
q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1);
|
||||
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||
|
||||
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
|
||||
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
|
||||
|
||||
if (j == 0) {
|
||||
qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4);
|
||||
}
|
||||
|
||||
scale += 4;
|
||||
}
|
||||
|
||||
sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));
|
||||
} break;
|
||||
default:
|
||||
assert(false && "Unsupported vector length");
|
||||
break;
|
||||
}
|
||||
}
|
||||
*s = sum;
|
||||
|
||||
#elif __ARM_NEON
|
||||
|
||||
uint32_t aux[3];
|
||||
uint32_t utmp[4];
|
||||
@@ -6622,6 +6948,77 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
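The VXE Q4_K path added below follows the same structure as the other backends: the per-sub-block minimums are folded in through the precomputed q8_K column sums (bsums) before the 4-bit quants are dotted against the int8 values. A scalar sketch of that correction term (illustration only, not code from the patch):

```cpp
// Sketch: the "mins" correction used by the K-quant dot products. Each of the
// 8 sub-blocks contributes  -dmin * min[j] * (sum of its 32 q8 values), which
// the SIMD code obtains cheaply from y->bsums (sums of 16 values each).
#include <stdint.h>

float q4_K_mins_correction_ref(float dmin, const uint8_t mins[8], const int16_t bsums[16]) {
    int32_t acc = 0;
    for (int j = 0; j < 8; ++j) {
        acc += (int32_t) mins[j] * (bsums[2*j] + bsums[2*j + 1]);
    }
    return -dmin * (float) acc;
}
```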
|
||||
|
||||
|
||||
*s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
const uint8x16_t v_lm = vec_splat_u8(0x0F);
|
||||
const int32x4_t v_z = vec_splat_s32(0);
|
||||
|
||||
uint8x16_t v_x[2];
|
||||
int8x16_t v_xl[2];
|
||||
int8x16_t v_y[2];
|
||||
|
||||
float sumf = 0;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
||||
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
|
||||
|
||||
const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
|
||||
const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
|
||||
const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
|
||||
|
||||
memcpy(utmp, x[i].scales, 12);
|
||||
|
||||
uint32x4_t v_mins8 = { 0 };
|
||||
v_mins8 = vec_insert(utmp[1] & kmask1, v_mins8, 0);
|
||||
v_mins8 = vec_insert(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), v_mins8, 1);
|
||||
|
||||
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
||||
utmp[0] &= kmask1;
|
||||
|
||||
const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8);
|
||||
|
||||
const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh);
|
||||
const int32x4_t v_minse = vec_mule(v_ysums, v_minsh);
|
||||
const int32x4_t v_mins = v_minso + v_minse;
|
||||
sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]);
|
||||
|
||||
const uint8_t * scales = (const uint8_t *)utmp;
|
||||
const uint8_t * restrict x0 = x[i].qs;
|
||||
const int8_t * restrict y0 = y[i].qs;
|
||||
|
||||
int32_t sumi1 = 0;
|
||||
int32_t sumi2 = 0;
|
||||
|
||||
for (int j = 0; j < QK_K/64; ++j) {
|
||||
v_x[0] = vec_xl(0 , x0);
|
||||
v_x[1] = vec_xl(16, x0);
|
||||
x0 += 32;
|
||||
|
||||
v_y[0] = vec_xl(0 , y0);
|
||||
v_y[1] = vec_xl(16, y0);
|
||||
y0 += 32;
|
||||
|
||||
v_xl[0] = (int8x16_t)vec_and(v_x[0], v_lm);
|
||||
v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm);
|
||||
|
||||
const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
|
||||
sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0];
|
||||
|
||||
v_y[0] = vec_xl(0 , y0);
|
||||
v_y[1] = vec_xl(16, y0);
|
||||
y0 += 32;
|
||||
|
||||
v_xl[0] = (int8x16_t)vec_sr(v_x[0], 4);
|
||||
v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4);
|
||||
|
||||
const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
|
||||
sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1];
|
||||
}
|
||||
|
||||
sumf += d * (sumi1 + sumi2);
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
#else
|
||||
|
||||
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
||||
@@ -7351,7 +7748,94 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r

    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));

    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
#elif defined(__VXE__) || defined(__VXE2__)
    const uint8x16_t v_lm = vec_splat_u8(0x0F);
    const uint8x16_t v_1m = vec_splat_u8(0x01);
    const uint8x16_t v_2m = vec_splat_u8(0x02);

    const int32x4_t v_z = vec_splat_s32(0);

    const uchar8x16_t v_minsm = {
        0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
    };

    int8x16_t q5b[4];
    uint8x16_t q5h[4];

    uint8x16_t v_xl[2];
    uint8x16_t v_xh[2];
    int8x16_t v_y[4];

    float sumf = 0;

    for (int i = 0; i < nb; ++i) {
        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);

        const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
        const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
        const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);

        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        const uint8x16_t v_mins16 = vec_xl(0, (const uint8_t *)utmp);
        const uint8x16_t v_mins8 = vec_perm(v_mins16, v_mins16, v_minsm);
        const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8);

        const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh);
        const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh);
        const int32x4_t v_mins = vec_add(v_minsho, v_minshe);
        const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];

        const uint8_t * scales = (const uint8_t *)utmp;
        const uint8_t * restrict x0l = x[i].qs;
        const uint8_t * restrict x0h = x[i].qh;
        const int8_t * restrict y0 = y[i].qs;

        v_xh[0] = vec_xl(0 , x0h);
        v_xh[1] = vec_xl(16, x0h);

        int32_t sumi = 0;
        for (int j = 0; j < QK_K/64; ++j) {
            v_xl[0] = vec_xl(0 , x0l);
            v_xl[1] = vec_xl(16, x0l);
            x0l += 32;

            v_y[0] = vec_xl(0 , y0);
            v_y[1] = vec_xl(16, y0);
            v_y[2] = vec_xl(32, y0);
            v_y[3] = vec_xl(48, y0);
            y0 += 64;

            q5h[0] = vec_sl(vec_and(v_1m, v_xh[0]), 4);
            q5h[1] = vec_sl(vec_and(v_1m, v_xh[1]), 4);
            q5h[2] = vec_sl(vec_and(v_2m, v_xh[0]), 3);
            q5h[3] = vec_sl(vec_and(v_2m, v_xh[1]), 3);
            v_xh[0] = vec_sr(v_xh[0], 2);
            v_xh[1] = vec_sr(v_xh[1], 2);

            q5b[0] = (int8x16_t)vec_or(vec_and(v_xl[0], v_lm), q5h[0]);
            q5b[1] = (int8x16_t)vec_or(vec_and(v_xl[1], v_lm), q5h[1]);
            q5b[2] = (int8x16_t)vec_or(vec_sr(v_xl[0], 4), q5h[2]);
            q5b[3] = (int8x16_t)vec_or(vec_sr(v_xl[1], 4), q5h[3]);

            int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]);
            int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]);

            sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++;
            sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++;
        }

        sumf += d * sumi - dmin * mins;
    }

    *s = sumf;
#else

    const uint8_t * scales = (const uint8_t*)&utmp[0];
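For q5_K the low four bits of every weight sit in qs and the fifth bit sits in qh, one bit plane per sub-block; the v_1m/v_2m masks above peel off two of those planes per iteration before vec_or merges them with the nibbles. A rough scalar equivalent for one weight (ql/qh/y8 and the loop index are illustrative, assuming the nibbles have already been split out):

// Scalar sketch: rebuild one 5-bit q5_K weight and accumulate the integer
// dot product with the q8_K activations of the same sub-block.
const int hbit = (qh[l] >> j) & 1;              // 5th bit for weight l, sub-block j
const int q5   = (ql[l] & 0x0F) | (hbit << 4);  // unsigned value in [0, 31]
sumi += q5 * y8[l];
// per block: sumf += d * scale_j * sumi - dmin * min_j * sum_of_y8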
@@ -8068,7 +8552,130 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
    }

    *s = hsum_float_8(acc);
#elif defined(__VXE__) || defined(__VXE2__)
    float sum = 0;

    // Lower 4-bit and upper 2-bit masks
    const uint8x16_t v_lm = vec_splat_u8(0x0F);
    const uint8x16_t v_um = vec_splat_u8(0x03);

    const int32x4_t v_z = vec_splat_s32(0);

    int8x16_t q6b[4];
    uint8x16_t q6h[4];

    uint8x16_t v_xl[4];
    uint8x16_t v_xh[2];
    int8x16_t v_y[4];

    for (int i = 0; i < nb; ++i) {
        const float d_all = GGML_FP16_TO_FP32(x[i].d);

        const uint8_t * restrict x0l = x[i].ql;
        const uint8_t * restrict x0h = x[i].qh;
        const int8_t * restrict y0 = y[i].qs;

        const int8_t * restrict scale = x[i].scales;

        const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
        const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);

        const int8x16_t v_scale = vec_xl(0, scale);
        const int16x8_t v_scalel = vec_unpackh(v_scale);
        const int16x8_t v_scaleh = vec_unpackl(v_scale);

        const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel);
        const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel);
        const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh);
        const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh);
        const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe;

        const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];

        int32_t isum = 0;
        for (int j = 0; j < QK_K/128; ++j) {
            // Load model upper 2 bits
            v_xh[0] = vec_xl(0 , x0h);
            v_xh[1] = vec_xl(16, x0h);
            x0h += 32;

            // Load model lower 4 bits
            v_xl[0] = vec_xl(0 , x0l);
            v_xl[1] = vec_xl(16, x0l);
            v_xl[2] = vec_xl(32, x0l);
            v_xl[3] = vec_xl(48, x0l);
            x0l += 64;

            // Load activation quants
            v_y[0] = vec_xl(0 , y0);
            v_y[1] = vec_xl(16, y0);
            v_y[2] = vec_xl(32, y0);
            v_y[3] = vec_xl(48, y0);
            y0 += 64;

            q6h[0] = vec_sl(vec_and(v_um, v_xh[0]), 4);
            q6h[1] = vec_sl(vec_and(v_um, v_xh[1]), 4);
            uint8x16_t shifted = vec_sr(v_xh[0], 2);
            q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
            shifted = vec_sr(v_xh[1], 2);
            q6h[3] = vec_sl(vec_and(v_um, shifted), 4);

            q6b[0] = (int8x16_t)(vec_or(vec_and(v_xl[0], v_lm), q6h[0]));
            q6b[1] = (int8x16_t)(vec_or(vec_and(v_xl[1], v_lm), q6h[1]));
            q6b[2] = (int8x16_t)(vec_or(vec_and(v_xl[2], v_lm), q6h[2]));
            q6b[3] = (int8x16_t)(vec_or(vec_and(v_xl[3], v_lm), q6h[3]));

            int32x4_t summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]);
            int32x4_t summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]);
            int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
            int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);

            isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
                    (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
                    (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
                    (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];

            scale += 4;

            // Load activation quants
            v_y[0] = vec_xl(0 , y0);
            v_y[1] = vec_xl(16, y0);
            v_y[2] = vec_xl(32, y0);
            v_y[3] = vec_xl(48, y0);
            y0 += 64;

            shifted = vec_sr(v_xh[0], 4);
            q6h[0] = vec_sl(vec_and(v_um, shifted), 4);
            shifted = vec_sr(v_xh[1], 4);
            q6h[1] = vec_sl(vec_and(v_um, shifted), 4);
            shifted = vec_sr(v_xh[0], 6);
            q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
            shifted = vec_sr(v_xh[1], 6);
            q6h[3] = vec_sl(vec_and(v_um, shifted), 4);

            q6b[0] = (int8x16_t)(vec_or(vec_sr(v_xl[0], 4), q6h[0]));
            q6b[1] = (int8x16_t)(vec_or(vec_sr(v_xl[1], 4), q6h[1]));
            q6b[2] = (int8x16_t)(vec_or(vec_sr(v_xl[2], 4), q6h[2]));
            q6b[3] = (int8x16_t)(vec_or(vec_sr(v_xl[3], 4), q6h[3]));

            summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]);
            summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]);
            summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
            summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);

            isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
                    (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
                    (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
                    (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];

            scale += 4;
        }

        sum += d_all * y[i].d * (isum - 32 * mins);
    }

    *s = sum;
#else

    int8_t aux8[QK_K];
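q6_K stores each weight as an unsigned 6-bit value with an implicit -32 offset, which is why the loop can run ggml_vec_dot on the raw 0..63 values and only subtract 32 * mins at the end: sum(scale*(q-32)*y) = sum(scale*q*y) - 32*sum(scale*y), and the second term is exactly the bsums-times-scales product computed before the loop. A scalar sketch of rebuilding one weight (illustrative names):

// Scalar sketch: one signed q6_K weight from its low 4 bits (ql) and its
// upper 2 bits (qh), matching the vec_and/vec_sl/vec_or sequence above.
const int lo = ql[l] & 0x0F;              // lower 4 bits
const int hi = (qh[l] >> shift) & 0x03;   // upper 2 bits for this group
const int q6 = (lo | (hi << 4)) - 32;     // signed value in [-32, 31]
isum_scalar += scale[group] * q6 * y8[l];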
@@ -8429,7 +9036,57 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
    }

    *s = 0.125f * hsum_float_8(accumf);

//#elif defined(__VXE__) || defined(__VXE2__)
//    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
//
//    uint32_t aux32[4];
//    const uint8_t * aux8 = (const uint8_t *)aux32;
//
//    float sumf = 0;
//
//    for (int i = 0; i < nb; ++i) {
//        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
//        const uint16_t * restrict q2 = x[i].qs;
//        const int8_t * restrict q8 = y[i].qs;
//
//        float sumf1 = 0, sumf2 = 0;
//
//        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
//            int8x16_t q8b0 = vec_xl( 0, q8);
//            int8x16_t q8b1 = vec_xl(16, q8);
//            int8x16_t q8b2 = vec_xl(32, q8);
//            int8x16_t q8b3 = vec_xl(48, q8);
//            q8 += 64;
//
//            memcpy(aux32, q2, 4 * sizeof(uint32_t));
//            q2 += 8;
//
//            int8x16_t q2u0 = { *(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1]) };
//            int8x16_t q2u1 = { *(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3]) };
//            int8x16_t q2u2 = { *(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9]) };
//            int8x16_t q2u3 = { *(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11]) };
//
//            int8x16_t q2s0 = { *(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127)) };
//            int8x16_t q2s1 = { *(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127)) };
//            int8x16_t q2s2 = { *(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127)) };
//            int8x16_t q2s3 = { *(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127)) };
//
//            q2u0 = vec_mul(q2u0, q2s0);
//            q2u1 = vec_mul(q2u1, q2s1);
//            q2u2 = vec_mul(q2u2, q2s2);
//            q2u3 = vec_mul(q2u3, q2s3);
//
//            const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u0, q8b0), q2u1, q8b1);
//            const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u2, q8b2), q2u3, q8b3);
//
//            sumf1 += (p1[0] + p1[1] + p1[2] + p1[3]) * (0.5f + (aux32[1] >> 28));
//            sumf2 += (p2[0] + p2[1] + p2[2] + p2[3]) * (0.5f + (aux32[3] >> 28));
//        }
//
//        sumf += d * (sumf1 + sumf2);
//    }
//
//    *s = 0.25f * sumf;
#else

    uint32_t aux32[2];
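The disabled VXE block mirrors the NEON path for iq2_xxs: each 8-value group is a lookup into iq2xxs_grid, its signs come from a 7-bit index into a precomputed sign table, and the top 4 bits of the control word become the sub-block scale (0.5f + ...), with a final 0.25f factor applied once. A rough scalar sketch of one group, kept only as orientation (table names follow the scalar kernels elsewhere in this file; details are illustrative):

// Scalar sketch: decode and accumulate one 8-value iq2_xxs group l.
const uint8_t * grid  = (const uint8_t *)(iq2xxs_grid + aux8[l]);
const uint8_t   signs = ksigns_iq2xs[(aux32[1] >> (7*l)) & 127];
const float     db    = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
for (int j = 0; j < 8; ++j) {
    sumf += db * grid[j] * (signs & kmask_iq2xs[j] ? -1.0f : 1.0f) * y8[8*l + j];
}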
@@ -11190,6 +11847,27 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *

    sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));

#elif defined(__VXE__) || defined(__VXE2__)
    const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
    const uint8x16_t v_m = vec_splat_u8(0x0F);

    for (; ib < nb; ++ib) {
        const block_iq4_nl * restrict x0 = &x[ib];
        const block_q8_0 * restrict y0 = &y[ib];

        const uint8x16_t v_x = vec_xl(0, x0->qs);
        int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
        int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);

        v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl);
        v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh);

        const int8x16_t v_yl = vec_xl(0 , y0->qs);
        const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs);
        const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);

        sumf += GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]);
    }
#endif
    for (; ib < nb; ++ib) {
        const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
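The two vec_perm calls are 16-entry table lookups: every 4-bit index selects one signed codebook value from kvalues_iq4nl, so the low and high nibbles are dequantized with a single shuffle each. A scalar sketch of the same block (the nibble layout is the usual one, low nibbles mapping to the first half of the block):

// Scalar sketch: index the 16-entry non-linear codebook with each nibble.
int sumi = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
    const int lo = kvalues_iq4nl[x0->qs[j] & 0x0F];
    const int hi = kvalues_iq4nl[x0->qs[j] >>   4];
    sumi += lo * y0->qs[j] + hi * y0->qs[j + QK4_NL/2];
}
sumf += GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d) * sumi;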
@@ -11468,6 +12146,56 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
    }

    *s = hsum_float_8(accum);
#elif defined(__VXE__) || defined(__VXE2__)
    const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
    const uint8x16_t v_m = vec_splat_u8(0x0F);

    float sumf = 0;

    for (int ibl = 0; ibl < nb; ++ibl) {
        const uint8_t * restrict q4 = x[ibl].qs;
        const int8_t * restrict q8 = y[ibl].qs;

        uint16_t h = x[ibl].scales_h;

        int sumi1 = 0, sumi2 = 0;
        for (int ib = 0; ib < QK_K/64; ++ib) {
            const uint8x16_t v_x0 = vec_xl(0 , q4);
            const uint8x16_t v_x1 = vec_xl(QK4_NL/2, q4);
            q4 += 32;

            int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
            int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
            int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
            int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);

            v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l);
            v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h);
            v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l);
            v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h);

            const int8x16_t v_y0 = vec_xl( 0, q8);
            const int8x16_t v_y1 = vec_xl(16, q8);
            const int8x16_t v_y2 = vec_xl(32, q8);
            const int8x16_t v_y3 = vec_xl(48, q8);
            q8 += 64;

            int32x4_t vsumi0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0), v_x0h, v_y1);
            int32x4_t vsumi1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y2), v_x1h, v_y3);

            int ls1 = ((x[ibl].scales_l[ib] & 0xF) | ((h << 4) & 0x30)) - 32;
            int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;

            h >>= 4;

            sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1;
            sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2;
        }

        sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
    }

    *s = sumf;

#else
    float sumf = 0;

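Each iq4_xs sub-block scale is 6 bits: a low nibble from scales_l plus two high bits peeled off the rolling scales_h word, re-centered by 32; that is what the ls1/ls2 lines compute for the even and odd sub-block of each iteration. The same bit layout, pulled out as a small per-sub-block helper for reference:

// Sketch: signed 6-bit scale of sub-block sb (0..7) of an iq4_xs block.
static inline int iq4_xs_scale(const uint8_t * scales_l, uint16_t scales_h, int sb) {
    const int lo = (sb & 1) ? (scales_l[sb/2] >> 4) : (scales_l[sb/2] & 0x0F);
    const int hi = (scales_h >> (2*sb)) & 0x03;
    return ((hi << 4) | lo) - 32;
}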
@@ -112,7 +112,8 @@ struct ggml_arm_arch_features_type {
|
||||
int has_i8mm;
|
||||
int has_sve;
|
||||
int sve_cnt;
|
||||
} ggml_arm_arch_features = {-1, -1, -1, -1, 0};
|
||||
int has_sme;
|
||||
} ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1};
|
||||
#endif
|
||||
|
||||
|
||||
@@ -236,6 +237,8 @@ typedef pthread_t ggml_thread_t;
|
||||
#else
|
||||
#if defined(__POWER9_VECTOR__)
|
||||
#define CACHE_LINE_SIZE 128
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
#define CACHE_LINE_SIZE 256
|
||||
#else
|
||||
#define CACHE_LINE_SIZE 64
|
||||
#endif
|
||||
@@ -1210,6 +1213,87 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
|
||||
#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
|
||||
#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
|
||||
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
|
||||
#define GGML_SIMD
|
||||
|
||||
// F32 s390x
|
||||
|
||||
#define GGML_F32_STEP 32
|
||||
#define GGML_F32_EPR 4
|
||||
|
||||
#define GGML_F32x4 __vector float
|
||||
#define GGML_F32x4_ZERO vec_splats(0.0f)
|
||||
#define GGML_F32x4_SET1 vec_splats
|
||||
#define GGML_F32x4_LOAD(p) vec_xl(0, p)
|
||||
#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
|
||||
#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
|
||||
#define GGML_F32x4_ADD vec_add
|
||||
#define GGML_F32x4_MUL vec_mul
|
||||
#define GGML_F32x4_REDUCE(res, x) \
|
||||
{ \
|
||||
int offset = GGML_F32_ARR >> 1; \
|
||||
for (int i = 0; i < offset; ++i) { \
|
||||
x[i] = vec_add(x[i], x[offset + i]); \
|
||||
} \
|
||||
offset >>= 1; \
|
||||
for (int i = 0; i < offset; ++i) { \
|
||||
x[i] = vec_add(x[i], x[offset + i]); \
|
||||
} \
|
||||
offset >>= 1; \
|
||||
for (int i = 0; i < offset; ++i) { \
|
||||
x[i] = vec_add(x[i], x[offset + i]); \
|
||||
} \
|
||||
res = vec_extract(x[0], 0) + \
|
||||
vec_extract(x[0], 1) + \
|
||||
vec_extract(x[0], 2) + \
|
||||
vec_extract(x[0], 3); \
|
||||
}
|
||||
|
||||
#define GGML_F32_VEC GGML_F32x4
|
||||
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
|
||||
#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
|
||||
#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
|
||||
#define GGML_F32_VEC_STORE GGML_F32x4_STORE
|
||||
#define GGML_F32_VEC_FMA GGML_F32x4_FMA
|
||||
#define GGML_F32_VEC_ADD GGML_F32x4_ADD
|
||||
#define GGML_F32_VEC_MUL GGML_F32x4_MUL
|
||||
#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
|
||||
|
||||
// F16 s390x
|
||||
#define GGML_F16_STEP GGML_F32_STEP
|
||||
#define GGML_F16_EPR GGML_F32_EPR
|
||||
|
||||
static inline __vector float __lzs_f16cx4_load(const ggml_fp16_t * x) {
|
||||
float tmp[4];
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
tmp[i] = GGML_FP16_TO_FP32(x[i]);
|
||||
}
|
||||
|
||||
return vec_xl(0, tmp);
|
||||
}
|
||||
|
||||
static inline void __lzs_f16cx4_store(ggml_fp16_t * x, __vector float y) {
|
||||
float arr[4];
|
||||
|
||||
vec_xst(y, 0, arr);
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
x[i] = GGML_FP32_TO_FP16(arr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#define GGML_F16_VEC GGML_F32x4
|
||||
#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
|
||||
#define GGML_F16_VEC_SET1 GGML_F32x4_SET1
|
||||
#define GGML_F16_VEC_LOAD(p, i) __lzs_f16cx4_load(p)
|
||||
#define GGML_F16_VEC_STORE(p, r, i) __lzs_f16cx4_store(p, r[i])
|
||||
#define GGML_F16_VEC_FMA GGML_F32x4_FMA
|
||||
#define GGML_F16_VEC_ADD GGML_F32x4_ADD
|
||||
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
|
||||
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
|
||||
|
||||
#endif
|
||||
|
||||
// GGML_F32_ARR / GGML_F16_ARR
|
||||
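These GGML_F32x4_*/GGML_F16_* definitions plug the s390x vector intrinsics into ggml's generic SIMD scaffolding: the portable vector loops are written once against GGML_F32_VEC, GGML_F32_STEP, GGML_F32_EPR, GGML_F32_ARR and GGML_F32_VEC_REDUCE, and each architecture only supplies these macros. A simplified sketch of how such a loop consumes them (the pattern, not the literal ggml_vec_dot_f32):

// Sketch: dot product written against the backend-agnostic macros above.
static float dot_f32(const int n, const float * x, const float * y) {
    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };

    const int np = (n & ~(GGML_F32_STEP - 1));
    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; ++j) {
            GGML_F32_VEC ax = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
            GGML_F32_VEC ay = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
            sum[j] = GGML_F32_VEC_FMA(sum[j], ax, ay);   // sum += ax*ay
        }
    }

    float res;
    GGML_F32_VEC_REDUCE(res, sum);   // horizontal add across all accumulators

    for (int i = np; i < n; ++i) {
        res += x[i]*y[i];            // scalar tail
    }
    return res;
}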
@@ -2381,15 +2465,20 @@ bool ggml_is_numa(void) {
|
||||
#define HWCAP2_I8MM (1 << 13)
|
||||
#endif
|
||||
|
||||
#if !defined(HWCAP2_SME)
|
||||
#define HWCAP2_SME (1 << 23)
|
||||
#endif
|
||||
|
||||
static void ggml_init_arm_arch_features(void) {
|
||||
#if defined(__linux__) && defined(__aarch64__)
|
||||
uint32_t hwcap = getauxval(AT_HWCAP);
|
||||
uint32_t hwcap2 = getauxval(AT_HWCAP2);
|
||||
|
||||
ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
|
||||
ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
|
||||
ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
|
||||
ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
|
||||
ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
|
||||
ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
|
||||
ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
|
||||
ggml_arm_arch_features.has_sme = !!(hwcap2 & HWCAP2_SME);
|
||||
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
|
||||
@@ -2412,6 +2501,11 @@ static void ggml_init_arm_arch_features(void) {
|
||||
}
|
||||
ggml_arm_arch_features.has_i8mm = oldp;
|
||||
|
||||
if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) {
|
||||
oldp = 0;
|
||||
}
|
||||
ggml_arm_arch_features.has_sme = oldp;
|
||||
|
||||
ggml_arm_arch_features.has_sve = 0;
|
||||
ggml_arm_arch_features.sve_cnt = 0;
|
||||
#else
|
||||
@@ -2435,6 +2529,12 @@ static void ggml_init_arm_arch_features(void) {
|
||||
ggml_arm_arch_features.has_sve = 0;
|
||||
ggml_arm_arch_features.sve_cnt = 0;
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2)
|
||||
ggml_arm_arch_features.has_sme = 1;
|
||||
#else
|
||||
ggml_arm_arch_features.has_sme = 0;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
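With has_sme added, SME support is now detected three ways: HWCAP2_SME from getauxval on Linux, the hw.optional.arm.FEAT_SME sysctl on Apple platforms, and the __ARM_FEATURE_SME compile-time macro as the fallback. A minimal standalone probe along the same lines (Linux-only sketch):

// Sketch: query SME support the same way the Linux branch above does.
#include <stdio.h>
#include <sys/auxv.h>

#ifndef HWCAP2_SME
#define HWCAP2_SME (1 << 23)   // same fallback value as defined above
#endif

int main(void) {
    const unsigned long hwcap2 = getauxval(AT_HWCAP2);
    printf("SME: %s\n", (hwcap2 & HWCAP2_SME) ? "yes" : "no");
    return 0;
}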
@@ -14402,6 +14502,14 @@ int ggml_cpu_has_vsx(void) {
|
||||
#endif
|
||||
}
|
||||
|
||||
int ggml_cpu_has_vxe(void) {
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
int ggml_cpu_has_neon(void) {
|
||||
#if defined(__ARM_ARCH) && defined(__ARM_NEON)
|
||||
return ggml_arm_arch_features.has_neon;
|
||||
@@ -14442,6 +14550,14 @@ int ggml_cpu_get_sve_cnt(void) {
|
||||
#endif
|
||||
}
|
||||
|
||||
int ggml_cpu_has_sme(void) {
|
||||
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)
|
||||
return ggml_arm_arch_features.has_sme;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_cpu_init(void) {
|
||||
// needed to initialize f16 tables
|
||||
{
|
||||
|
||||
@@ -14,6 +14,10 @@
|
||||
#include "ggml-cpu-hbm.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_CPU_KLEIDIAI
|
||||
#include "kleidiai/kleidiai.h"
|
||||
#endif
|
||||
|
||||
#if defined(__APPLE__)
|
||||
#include <sys/types.h>
|
||||
#include <sys/sysctl.h>
|
||||
@@ -39,6 +43,12 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_CPU_KLEIDIAI
|
||||
if (ggml_backend_cpu_kleidiai_buffer_type()) {
|
||||
bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_CPU_AARCH64
|
||||
if (ggml_backend_cpu_aarch64_buffer_type()) {
|
||||
bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
|
||||
@@ -538,12 +548,18 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
|
||||
static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
|
||||
features.push_back({ "SVE_CNT", sve_cnt.c_str() });
|
||||
}
|
||||
if (ggml_cpu_has_sme()) {
|
||||
features.push_back({ "SME", "1" });
|
||||
}
|
||||
if (ggml_cpu_has_riscv_v()) {
|
||||
features.push_back({ "RISCV_V", "1" });
|
||||
}
|
||||
if (ggml_cpu_has_vsx()) {
|
||||
features.push_back({ "VSX", "1" });
|
||||
}
|
||||
if (ggml_cpu_has_vxe()) {
|
||||
features.push_back({ "VXE", "1" });
|
||||
}
|
||||
if (ggml_cpu_has_wasm_simd()) {
|
||||
features.push_back({ "WASM_SIMD", "1" });
|
||||
}
|
||||
@@ -559,6 +575,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
|
||||
#ifdef GGML_USE_OPENMP
|
||||
features.push_back({ "OPENMP", "1" });
|
||||
#endif
|
||||
#ifdef GGML_USE_CPU_KLEIDIAI
|
||||
features.push_back({ "KLEIDIAI", "1" });
|
||||
#endif
|
||||
#ifdef GGML_USE_CPU_AARCH64
|
||||
features.push_back({ "AARCH64_REPACK", "1" });
|
||||
#endif
|
||||
|
||||
259
ggml/src/ggml-cpu/kleidiai/kernels.cpp
Normal file
@@ -0,0 +1,259 @@
|
||||
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
// KleidiAI micro-kernels
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
||||
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
||||
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
|
||||
#include "kai_common.h"
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
#define NELEMS(x) sizeof(x) / sizeof(*x)
|
||||
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
#if defined(__ARM_FEATURE_SME)
|
||||
{
|
||||
/* SME GEMM */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
},
|
||||
/* SME GEMV */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
},
|
||||
/* .lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .require_aligned_m_idx = */ true,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_SME,
|
||||
},
|
||||
#endif
|
||||
#if defined(__APPLE__)
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
{
|
||||
/* DOTPROD GEMM */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
},
|
||||
/* DOTPROD GEMV */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
},
|
||||
/* .lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .require_aligned_m_idx = */ false,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
||||
},
|
||||
#endif
|
||||
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
{
|
||||
/* i8mm GEMM */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
},
|
||||
/* i8mm GEMV */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
},
|
||||
/* .lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .require_aligned_m_idx = */ false,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
},
|
||||
#endif
|
||||
#else
|
||||
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
{
|
||||
/* i8mm GEMM */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
},
|
||||
/* i8mm GEMV */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
},
|
||||
/* .lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .require_aligned_m_idx = */ false,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
},
|
||||
#endif
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
{
|
||||
/* DOTPROD GEMM */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
},
|
||||
/* DOTPROD GEMV */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
},
|
||||
/* .lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .require_aligned_m_idx = */ false,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
||||
},
|
||||
#endif
|
||||
#endif
|
||||
};
|
||||
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {
|
||||
ggml_kleidiai_kernels * kernels = nullptr;
|
||||
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
||||
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
|
||||
kernels = &gemm_gemv_kernels[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return kernels;
|
||||
}
|
||||
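Kernel selection is a priority list: the table is ordered from most to least capable (SME first where compiled in, then i8mm, then plain dotprod), and the first entry whose required_cpu bits are all present is returned. A hedged sketch of driving it by hand, mirroring the context initialisation in kleidiai.cpp further below:

// Sketch: build a runtime feature mask and pick the best matching kernel set.
cpu_feature features = CPU_FEATURE_NONE;
if (ggml_cpu_has_dotprod())     features |= CPU_FEATURE_DOTPROD;
if (ggml_cpu_has_matmul_int8()) features |= CPU_FEATURE_I8MM;
if (ggml_cpu_has_sve())         features |= CPU_FEATURE_SVE;
if (ggml_cpu_has_sme())         features |= CPU_FEATURE_SME;

ggml_kleidiai_kernels * kernels = ggml_kleidiai_select_kernels(features);
if (kernels == NULL) {
    // no micro-kernel set matches this CPU: fall back to the regular CPU path
}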
61
ggml/src/ggml-cpu/kleidiai/kernels.h
Normal file
@@ -0,0 +1,61 @@
|
||||
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
enum cpu_feature {
|
||||
CPU_FEATURE_NONE = 0,
|
||||
CPU_FEATURE_DOTPROD = 1,
|
||||
CPU_FEATURE_I8MM = 2,
|
||||
CPU_FEATURE_SVE = 4,
|
||||
CPU_FEATURE_SME = 8
|
||||
};
|
||||
inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
|
||||
lhs = static_cast<cpu_feature>(lhs | rhs);
|
||||
return lhs;
|
||||
}
|
||||
inline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) {
|
||||
return static_cast<cpu_feature>(static_cast<int>(lhs) | static_cast<int>(rhs));
|
||||
}
|
||||
|
||||
struct kernel_info {
|
||||
size_t (*get_m_step)(void);
|
||||
size_t (*get_n_step)(void);
|
||||
size_t (*get_mr)(void);
|
||||
size_t (*get_nr)(void);
|
||||
size_t (*get_kr)(void);
|
||||
size_t (*get_sr)(void);
|
||||
size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl);
|
||||
size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl);
|
||||
size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
|
||||
size_t (*get_dst_size)(size_t m, size_t n);
|
||||
void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
|
||||
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max);
|
||||
};
|
||||
|
||||
struct lhs_packing_info {
|
||||
size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
|
||||
size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
||||
size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
||||
void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
|
||||
size_t lhs_stride, void* lhs_packed);
|
||||
bool require_aligned_m_idx;
|
||||
};
|
||||
|
||||
struct rhs_packing_info {
|
||||
size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
|
||||
void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
|
||||
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params);
|
||||
};
|
||||
|
||||
struct ggml_kleidiai_kernels {
|
||||
kernel_info gemm;
|
||||
kernel_info gemv;
|
||||
lhs_packing_info lhs_info;
|
||||
rhs_packing_info rhs_info;
|
||||
|
||||
cpu_feature required_cpu;
|
||||
};
|
||||
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);
|
||||
287
ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
Normal file
@@ -0,0 +1,287 @@
|
||||
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
#include <arm_neon.h>
|
||||
#include <assert.h>
|
||||
#include <cfloat>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
#if defined(__linux__)
|
||||
#include <asm/hwcap.h>
|
||||
#include <sys/auxv.h>
|
||||
#elif defined(__APPLE__)
|
||||
#include <string_view>
|
||||
#include <sys/sysctl.h>
|
||||
#include <sys/types.h>
|
||||
#elif defined(_WIN32)
|
||||
#include <windows.h>
|
||||
#include <excpt.h>
|
||||
#endif
|
||||
|
||||
#include "kleidiai.h"
|
||||
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-threading.h"
|
||||
#include "ggml-cpu-traits.h"
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
#include "kai_common.h"
|
||||
|
||||
#define GGML_COMMON_DECL_CPP
|
||||
#include "ggml-common.h"
|
||||
|
||||
struct ggml_kleidiai_context {
|
||||
ggml_kleidiai_kernels * kernels;
|
||||
} static ctx = { NULL };
|
||||
|
||||
static void init_kleidiai_context(void) {
|
||||
|
||||
ggml_critical_section_start();
|
||||
static bool initialized = false;
|
||||
|
||||
if (!initialized) {
|
||||
initialized = true;
|
||||
const char *env_var = getenv("GGML_KLEIDIAI_SME");
|
||||
int sme_enabled = 0;
|
||||
|
||||
cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
|
||||
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
|
||||
(ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
|
||||
|
||||
if (env_var) {
|
||||
sme_enabled = atoi(env_var);
|
||||
}
|
||||
|
||||
if (sme_enabled != 0) {
|
||||
features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
||||
}
|
||||
ctx.kernels = ggml_kleidiai_select_kernels(features);
|
||||
}
|
||||
ggml_critical_section_end();
|
||||
}
|
||||
|
||||
static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
|
||||
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
|
||||
return tensor->ne[dim];
|
||||
}
|
||||
|
||||
namespace ggml::cpu::kleidiai {
|
||||
class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
|
||||
GGML_ASSERT(ctx.kernels);
|
||||
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
|
||||
|
||||
size_t k = op->src[0]->ne[0];
|
||||
size_t m = op->src[1]->ne[1];
|
||||
|
||||
size_t mr = kernel->get_mr();
|
||||
size_t kr = kernel->get_kr();
|
||||
size_t sr = kernel->get_sr();
|
||||
|
||||
size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
|
||||
if (dst->op == GGML_OP_MUL_MAT) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT(ctx.kernels);
|
||||
kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
|
||||
lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
|
||||
|
||||
GGML_ASSERT(kernel);
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const size_t k = ne00;
|
||||
const size_t m = ne11;
|
||||
const size_t n = ne01;
|
||||
|
||||
const size_t n_step = kernel->get_n_step();
|
||||
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
|
||||
const size_t n_start = ith * num_n_per_thread;
|
||||
|
||||
size_t n_to_process = num_n_per_thread;
|
||||
if ((n_start + n_to_process) > n) {
|
||||
n_to_process = n - n_start;
|
||||
}
|
||||
|
||||
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
|
||||
uint8_t * lhs_packed = (uint8_t*)params->wdata;
|
||||
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
|
||||
|
||||
size_t mr = kernel->get_mr();
|
||||
size_t kr = kernel->get_kr();
|
||||
size_t sr = kernel->get_sr();
|
||||
|
||||
// Calculate number of columns to be processed per thread
|
||||
const bool use_multithread = lhs_info->require_aligned_m_idx && m <= mr ? false : true;
|
||||
const size_t num_m_per_thread = use_multithread ? kai_roundup(m, nth) / nth : m;
|
||||
const size_t m_start = ith * num_m_per_thread;
|
||||
size_t m_to_process = num_m_per_thread;
|
||||
if ((m_start + m_to_process) > m) {
|
||||
m_to_process = m - m_start;
|
||||
}
|
||||
|
||||
if(m_start < m) {
|
||||
// Transform LHS
|
||||
const size_t src_stride = src1->nb[1];
|
||||
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1]));
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
|
||||
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
||||
|
||||
lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
// Perform the operation
|
||||
const size_t dst_stride = dst->nb[1];
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
|
||||
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
|
||||
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
|
||||
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
|
||||
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
|
||||
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
|
||||
|
||||
kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
|
||||
dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public:
|
||||
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
|
||||
GGML_ASSERT(ctx.kernels);
|
||||
const size_t n = tensor->ne[1];
|
||||
const size_t k = tensor->ne[0];
|
||||
size_t nr = ctx.kernels->gemm.get_nr();
|
||||
size_t kr = ctx.kernels->gemm.get_kr();
|
||||
size_t sr = ctx.kernels->gemm.get_sr();
|
||||
|
||||
#ifndef NDEBUG
|
||||
const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
|
||||
GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
|
||||
#endif
|
||||
struct kai_rhs_pack_qs4cxs1s0_param params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, ¶ms);
|
||||
|
||||
return 0;
|
||||
|
||||
GGML_UNUSED(data_size);
|
||||
}
|
||||
};
|
||||
|
||||
static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
|
||||
static tensor_traits traits;
|
||||
return &traits;
|
||||
}
|
||||
} // namespace ggml::cpu::kleidiai
|
||||
|
||||
static void ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
|
||||
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
|
||||
const void * data, size_t offset, size_t size) {
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(size == ggml_nbytes(tensor));
|
||||
|
||||
auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra;
|
||||
auto OK = tensor_traits->repack(tensor, data, size);
|
||||
|
||||
GGML_ASSERT(OK == 0);
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
||||
return "CPU_KLEIDIAI";
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
|
||||
|
||||
if (buffer == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
buffer->buft = buft;
|
||||
buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;
|
||||
buffer->iface.set_tensor = ggml_backend_cpu_kleidiai_buffer_set_tensor;
|
||||
buffer->iface.get_tensor = nullptr;
|
||||
buffer->iface.cpy_tensor = nullptr;
|
||||
return buffer;
|
||||
}
|
||||
|
||||
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
return TENSOR_ALIGNMENT;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
namespace ggml::cpu::kleidiai {
|
||||
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
||||
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
||||
if ( op->op == GGML_OP_MUL_MAT &&
|
||||
op->src[0]->type == GGML_TYPE_Q4_0 &&
|
||||
op->src[0]->buffer &&
|
||||
(ggml_n_dims(op->src[0]) == 2) &&
|
||||
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
|
||||
) {
|
||||
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type == GGML_TYPE_F32 &&
|
||||
ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
|
||||
if (op->op == GGML_OP_MUL_MAT) {
|
||||
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
||||
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
} // namespace ggml::cpu::kleidiai
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
|
||||
static ggml::cpu::kleidiai::extra_buffer_type ctx;
|
||||
static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,
|
||||
/* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
|
||||
/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
|
||||
/* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
|
||||
/* .is_host = */ nullptr,
|
||||
},
|
||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
|
||||
/* .context = */ &ctx,
|
||||
};
|
||||
|
||||
init_kleidiai_context();
|
||||
|
||||
return &ggml_backend_cpu_buffer_type_kleidiai;
|
||||
}
|
||||
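Nothing calls these functions directly in normal use: the buffer type is appended to the CPU backend's extra buffer types (see the GGML_USE_CPU_KLEIDIAI hunk in ggml-cpu.cpp above), the scheduler then places eligible Q4_0 weights into it, set_tensor repacks them, and supports_op routes matching MUL_MAT ops to the KleidiAI tensor traits. A hedged sketch of exercising the buffer type by hand (weight and q4_0_data are placeholders, error handling omitted, exact allocation flow may differ):

// Sketch: put a Q4_0 weight tensor into the KleidiAI buffer type so that the
// repacking set_tensor path above runs (uses the ggml-alloc/ggml-backend APIs).
ggml_backend_buffer_type_t buft = ggml_backend_cpu_kleidiai_buffer_type();
ggml_backend_buffer_t      buf  = ggml_backend_buft_alloc_buffer(buft, ggml_nbytes(weight));

struct ggml_tallocr alloc = ggml_tallocr_new(buf);
ggml_tallocr_alloc(&alloc, weight);                                   // installs the tensor traits
ggml_backend_tensor_set(weight, q4_0_data, 0, ggml_nbytes(weight));   // triggers repack()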
17
ggml/src/ggml-cpu/kleidiai/kleidiai.h
Normal file
@@ -0,0 +1,17 @@
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ggml-alloc.h"

#ifdef __cplusplus
extern "C" {
#endif

ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void);

#ifdef __cplusplus
}
#endif
@@ -7,7 +7,8 @@ if (CUDAToolkit_FOUND)

    if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
        # native == GPUs available at build time
        # 52 == Maxwell, lowest CUDA 12 standard
        # 50 == Maxwell, lowest CUDA 12 standard
        # 60 == P100, FP16 CUDA intrinsics
        # 61 == Pascal, __dp4a instruction (per-byte integer dot product)
        # 70 == V100, FP16 tensor cores
@@ -17,7 +17,7 @@ if (CUDAToolkit_FOUND)
        elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
            set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80")
        else()
            set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75;80")
            set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75;80")
        endif()
    endif()
    message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
@@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND)
        add_compile_definitions(GGML_CUDA_NO_VMM)
    endif()

    if (NOT GGML_CUDA_FA)
        add_compile_definitions(GGML_CUDA_NO_FA)
    endif()

    if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
        add_compile_definitions(GGML_CUDA_F16)
    endif()

@@ -41,12 +41,13 @@
|
||||
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
|
||||
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
|
||||
|
||||
#define GGML_CUDA_CC_PASCAL 600
|
||||
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
|
||||
#define GGML_CUDA_CC_VOLTA 700
|
||||
#define GGML_CUDA_CC_TURING 750
|
||||
#define GGML_CUDA_CC_AMPERE 800
|
||||
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
|
||||
#define GGML_CUDA_CC_PASCAL 600
|
||||
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
|
||||
#define GGML_CUDA_CC_VOLTA 700
|
||||
#define GGML_CUDA_CC_TURING 750
|
||||
#define GGML_CUDA_CC_AMPERE 800
|
||||
#define GGML_CUDA_CC_ADA_LOVELACE 890
|
||||
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
|
||||
|
||||
// GCN/CNDA, wave size is 64
|
||||
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
|
||||
@@ -199,9 +200,13 @@ typedef float2 dfloat2;
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING

#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)

static bool fp16_available(const int cc) {
    return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
@@ -231,6 +236,10 @@ static bool new_mma_available(const int cc) {
    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}

static bool cp_async_available(const int cc) {
    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}

static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
    return __AMDGCN_WAVEFRONT_SIZE;
@@ -402,13 +411,13 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {

#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)

#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
    return __dp4a(a, b, c);
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
    const int8_t * a8 = (const int8_t *) &a;
    const int8_t * b8 = (const int8_t *) &b;
    return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)

#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
}

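Illustration only, not part of the patch: a host-side reference for what __dp4a and the scalar fallback above compute, namely a dot product of four packed signed bytes added to an accumulator.

// Reference semantics of ggml_cuda_dp4a for int8 data packed into 32 bit integers.
static int dp4a_reference(const int a, const int b, const int c) {
    const int8_t * a8 = (const int8_t *) &a;
    const int8_t * b8 = (const int8_t *) &b;
    return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
}
// Example: a = 0x01010101 (four bytes of 1), b = 0x02020202 (four bytes of 2), c = 0
// -> dp4a_reference(a, b, 0) == 8.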
ggml/src/ggml-cuda/cp-async.cuh (new file, 46 lines)
@@ -0,0 +1,46 @@
// Simplified API for asynchronous data loading.

#include "common.cuh"

// Copies data from global to shared memory, cg == cache global.
// Both the src and dst pointers must be aligned to 16 bit.
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements.
template <int preload>
static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
    static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
#ifdef CP_ASYNC_AVAILABLE
#if CUDART_VERSION >= 11040
    if (preload == 256) {
        asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
            : : "r"(dst), "l"(src));
    } else if (preload == 128) {
        asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
            : : "r"(dst), "l"(src));
    } else if (preload == 64) {
        asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
            : : "r"(dst), "l"(src));
    } else
#endif // CUDART_VERSION >= 11040
    {
        asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
            : : "r"(dst), "l"(src));
    }
#else
    GGML_UNUSED(dst);
    GGML_UNUSED(src);
    NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}

// Makes each thread wait until its asynchronous data copies are done.
// This does NOT provide any additional synchronization.
// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
static __device__ __forceinline__ void cp_async_wait_all() {
#ifdef CP_ASYNC_AVAILABLE
    asm volatile("cp.async.wait_all;");
#else
    NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}
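Illustration only, not part of the patch: a minimal sketch of how the two helpers are meant to be used together: convert the shared-memory destination with __cvta_generic_to_shared, issue the 16-byte copies, then wait and synchronize before other warps read the data. The kernel name and shapes are hypothetical.

// Hypothetical usage sketch of cp_async_cg_16 / cp_async_wait_all.
__global__ void copy_tile_example(const char * __restrict__ src) {
    extern __shared__ char tile[];  // dynamic shared memory, 16-byte aligned
    const unsigned int dst = (unsigned int) __cvta_generic_to_shared(tile + 16*threadIdx.x);
    cp_async_cg_16<0>(dst, src + 16*threadIdx.x); // each thread copies 16 bytes, no L2 preload
    cp_async_wait_all();                          // wait for this thread's own copies
    __syncthreads();                              // required before other warps read the tile
    // ... consume tile[] ...
}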
@@ -1,4 +1,5 @@
#include "cpy.cuh"
#include "dequantize.cuh"

typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

@@ -82,13 +83,14 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
}

static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
    const block_q8_0 * xi = (const block_q8_0 *) cxi;
    float * dsti = (float *) cdsti;
    float * cdstf = (float *)(cdsti);

    const float d = (float)xi->d;

    for (int j = 0; j < QK8_0; j++) {
        dsti[j] = xi->qs[j] * d;
#pragma unroll
    for (int j = 0; j < QK8_0; j += 2) {
        dfloat2 dq;
        dequantize_q8_0(cxi, 0, j, dq);
        *(cdstf + j) = dq.x;
        *(cdstf + j + 1) = dq.y;
    }
}

@@ -225,6 +227,18 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
    memcpy(dsti->qh, &qh, sizeof(qh));
}

template<dequantize_kernel_t dequant, int qk>
static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
    float * cdstf = (float *)(cdsti);

#pragma unroll
    for (int j = 0; j < qk/2; j++) {
        dfloat2 dq;
        dequant(cxi, 0, j, dq);
        *(cdstf + j) = dq.x;
        *(cdstf + j + qk/2) = dq.y;
    }
}
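One detail worth spelling out (illustration, not from the patch): the dequantize kernels deliver value pairs that are half a block apart, which is why dq.y is written qk/2 elements after dq.x.

// Worked example for Q4_0 (qk == 32): iteration j dequantizes block positions j and j + 16,
// matching the nibble layout of block_q4_0 (low nibbles first, high nibbles second):
//   j = 0  -> cdstf[0]  = dq.x, cdstf[16] = dq.y
//   j = 1  -> cdstf[1]  = dq.x, cdstf[17] = dq.y
//   ...
//   j = 15 -> cdstf[15] = dq.x, cdstf[31] = dq.y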
|
||||
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
|
||||
if (x <= val[0]) return 0;
|
||||
@@ -387,6 +401,19 @@ static void ggml_cpy_f32_q4_0_cuda(
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q4_0_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02,
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q4_1_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
@@ -398,6 +425,19 @@ static void ggml_cpy_f32_q4_1_cuda(
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q4_1_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02,
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q5_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
@@ -409,6 +449,19 @@ static void ggml_cpy_f32_q5_0_cuda(
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q5_0_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02,
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q5_1_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
@@ -420,6 +473,19 @@ static void ggml_cpy_f32_q5_1_cuda(
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q5_1_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02,
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_iq4_nl_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
@@ -488,14 +554,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
@@ -524,14 +601,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
|
||||
@@ -123,13 +123,13 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
|
||||
|
||||
if (nbytes_shared <= smpbo) {
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
if (!shared_memory_limit_raised[id]) {
|
||||
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
||||
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
||||
shared_memory_limit_raised[id] = true;
|
||||
}
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
||||
} else {
|
||||
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
||||
@@ -175,13 +175,13 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
|
||||
if (nbytes_shared <= smpbo) {
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
if (!shared_memory_limit_raised[id]) {
|
||||
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
||||
shared_memory_limit_raised[id] = true;
|
||||
}
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
||||
} else {
|
||||
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
||||
|
||||
@@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
||||
nullptr;
|
||||
}
|
||||
|
||||
// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wpass-failed"
|
||||
#endif // __clang__
|
||||
|
||||
template<int D, int ncols, int KQ_stride> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
||||
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
||||
|
||||
const int iter_k = ne11 / KQ_stride;
|
||||
const int iter_j = (ne01 + (ncols - 1)) / ncols;
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
const int j = blockIdx.y;
|
||||
const int c = blockIdx.z;
|
||||
const int jc = j*ncols2 + c;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
|
||||
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
||||
|
||||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
|
||||
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
||||
@@ -548,22 +546,22 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
const int channel = kbc0 / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
|
||||
|
||||
dst += jt*ncols*ne02*D + channel*D;
|
||||
if (jt*ncols1 + j >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
|
||||
|
||||
// Load the partial result that needs a fixup:
|
||||
float dst_val[ncols] = {0.0f};
|
||||
float max_val[ncols] = {0.0f};
|
||||
float rowsum[ncols] = {0.0f};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (jt*ncols + j >= ne01) {
|
||||
break;
|
||||
}
|
||||
dst_val[j] = dst[j*ne02*D + threadIdx.x];
|
||||
float dst_val = 0.0f;
|
||||
float max_val = 0.0f;
|
||||
float rowsum = 0.0f;
|
||||
{
|
||||
dst_val = *dst;
|
||||
|
||||
const float2 tmp = dst_fixup[bidx0*ncols + j];
|
||||
max_val[j] = tmp.x;
|
||||
rowsum[j] = tmp.y;
|
||||
const float2 tmp = dst_fixup[bidx0*ncols + jc];
|
||||
max_val = tmp.x;
|
||||
rowsum = tmp.y;
|
||||
}
|
||||
|
||||
// Iterate over previous blocks and compute the combined results.
|
||||
@@ -571,36 +569,30 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
int bidx = bidx0 - 1;
|
||||
int kbc_stop = kbc0;
|
||||
while(true) {
|
||||
const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
|
||||
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
continue;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (jt*ncols + j >= ne01) {
|
||||
break;
|
||||
}
|
||||
const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
|
||||
const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
|
||||
|
||||
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
|
||||
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
|
||||
|
||||
// Scale the current and new value accumulators depending on the max. values.
|
||||
const float max_val_new = fmaxf(max_val[j], tmp.x);
|
||||
// Scale the current and new value accumulators depending on the max. values.
|
||||
const float max_val_new = fmaxf(max_val, tmp.x);
|
||||
|
||||
const float diff_val = max_val[j] - max_val_new;
|
||||
const float diff_add = tmp.x - max_val_new;
|
||||
const float diff_val = max_val - max_val_new;
|
||||
const float diff_add = tmp.x - max_val_new;
|
||||
|
||||
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
|
||||
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
|
||||
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
|
||||
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
|
||||
|
||||
dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
|
||||
rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y;
|
||||
dst_val = scale_val*dst_val + scale_add*dst_add;
|
||||
rowsum = scale_val*rowsum + scale_add*tmp.y;
|
||||
|
||||
max_val[j] = max_val_new;
|
||||
}
|
||||
max_val = max_val_new;
|
||||
|
||||
// If this block started in a previous tile we are done and don't need to combine additional partial results.
|
||||
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
|
||||
@@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
}
|
||||
|
||||
// Write back final result:
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (jt*ncols + j >= ne01) {
|
||||
return;
|
||||
}
|
||||
dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
|
||||
}
|
||||
*dst = dst_val / rowsum;
|
||||
}
|
||||
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif // __clang__
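Illustration only, not part of the patch: the fixup above merges partial FlashAttention results with the usual online-softmax rescaling. A hypothetical stand-alone helper showing the same combination rule:

// Hypothetical sketch: combining two partial (value, running max, rowsum) accumulators
// exactly the way the stream-k fixup does, using SOFTMAX_FTZ_THRESHOLD from fattn-common.cuh.
struct partial_result { float val, max, rowsum; };
static __device__ partial_result combine_partials(const partial_result a, const partial_result b) {
    const float m  = fmaxf(a.max, b.max);
    const float sa = a.max - m >= SOFTMAX_FTZ_THRESHOLD ? expf(a.max - m) : 0.0f;
    const float sb = b.max - m >= SOFTMAX_FTZ_THRESHOLD ? expf(b.max - m) : 0.0f;
    // the final output is val / rowsum, as in the write-back above
    return {sa*a.val + sb*b.val, m, sa*a.rowsum + sb*b.rowsum};
}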
|
||||
|
||||
template<int D, int parallel_blocks> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
@@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) {
|
||||
}
|
||||
|
||||
// parallel_blocks == 0 is stream-k decomposition
|
||||
template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
|
||||
template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
|
||||
void launch_fattn(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
|
||||
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
@@ -716,7 +700,9 @@ void launch_fattn(
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int nsm = ggml_cuda_info().devices[id].nsm;
|
||||
|
||||
ggml_cuda_pool_alloc<half> K_f16(pool);
|
||||
ggml_cuda_pool_alloc<half> V_f16(pool);
|
||||
@@ -761,24 +747,26 @@ void launch_fattn(
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
}
|
||||
|
||||
const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
|
||||
const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
||||
|
||||
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
||||
dim3 blocks_num;
|
||||
if (parallel_blocks == 0) {
|
||||
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
||||
const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
|
||||
const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
|
||||
const bool short_context = K->ne[1] < 4096;
|
||||
const int max_blocks = 2*nsm;
|
||||
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
|
||||
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
|
||||
|
||||
const int nblocks_stream_k = 2*nsm;
|
||||
const int nblocks_stream_k = max_blocks;
|
||||
|
||||
blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
|
||||
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
|
||||
|
||||
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
|
||||
blocks_num.y = 1;
|
||||
blocks_num.z = 1;
|
||||
|
||||
dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
|
||||
} else {
|
||||
blocks_num.x = parallel_blocks*ntiles_x;
|
||||
blocks_num.y = Q->ne[2];
|
||||
@@ -790,7 +778,6 @@ void launch_fattn(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
@@ -827,11 +814,11 @@ void launch_fattn(
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if constexpr (parallel_blocks == 0) {
|
||||
if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine = blocks_num;
|
||||
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
||||
|
||||
flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
|
||||
flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
@@ -44,12 +44,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
@@ -290,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
|
||||
@@ -302,14 +297,14 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
|
||||
@@ -44,10 +44,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
@@ -285,6 +282,9 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
|
||||
@@ -296,14 +296,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
|
||||
@@ -41,12 +41,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
@@ -300,7 +295,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
@@ -310,7 +305,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
|
||||
@@ -41,10 +41,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
@@ -281,6 +278,9 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
@@ -290,7 +290,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
|
||||
@@ -51,7 +51,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -425,7 +425,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
}
|
||||
|
||||
constexpr int get_max_power_of_2(int x) {
|
||||
@@ -478,7 +478,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
return;
|
||||
}
|
||||
if (2*blocks_num_pb1 < 2*nsm) {
|
||||
@@ -493,7 +493,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
return;
|
||||
}
|
||||
constexpr int parallel_blocks = 1;
|
||||
@@ -507,7 +507,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
@@ -8,28 +8,50 @@
|
||||
#include "fattn-wmma-f16.cuh"
|
||||
#include "fattn.cuh"
|
||||
|
||||
template <int cols_per_block>
|
||||
template <int D, int ncols2>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
if (Q->ne[1] <= 8/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 16/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
|
||||
}
|
||||
|
||||
template <int ncols2>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -38,24 +60,35 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
|
||||
}
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
if (Q->ne[1] <= 8) {
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const float use_gqa_opt = mask && max_bias == 0.0f;
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 16) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
|
||||
if (use_gqa_opt && gqa_ratio == 4) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
|
||||
if (use_gqa_opt && gqa_ratio == 2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
|
||||
}
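Illustration only, not part of the patch: how the GQA dispatch above plays out for typical grouped-query shapes.

// Hypothetical model with 32 query heads (Q->ne[2]) sharing 8 KV heads (K->ne[2]):
//   gqa_ratio = 32 / 8 = 4
// With a mask present and max_bias == 0, use_gqa_opt is true and the ncols2 == 4 path
// is taken, so queries from the same KV group share one tile; otherwise the code falls
// back to the ncols2 == 1 instantiation.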
|
||||
|
||||
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
||||
@@ -209,8 +242,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
@@ -252,7 +288,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
|
||||
K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
|
||||
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
return;
|
||||
|
||||
@@ -261,6 +261,12 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
|
||||
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
|
||||
device_vmm ? "yes" : "no", prop.warpSize);
|
||||
#elif defined(GGML_USE_MUSA)
|
||||
// TODO: refine the .cc to reflect MUSA's actual CC capabilities
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
|
||||
#else
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
||||
@@ -1782,9 +1788,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
}
|
||||
}
|
||||
#else
|
||||
#ifdef GGML_USE_MUSA
|
||||
GGML_ASSERT(false);
|
||||
#else // !GGML_USE_MUSA
|
||||
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
||||
// use cublasGemmStridedBatchedEx
|
||||
@@ -1827,7 +1830,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
cu_compute_type,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
#endif // GGML_USE_MUSA
|
||||
#endif
|
||||
|
||||
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||
@@ -3073,15 +3075,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
|
||||
return true;
|
||||
}
|
||||
@@ -3189,7 +3203,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
        case GGML_OP_FLASH_ATTN_EXT: {
#ifndef FLASH_ATTN_AVAILABLE
            return false;
#endif
#endif // FLASH_ATTN_AVAILABLE
            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                return false;
            }

||||
@@ -4,11 +4,12 @@
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
//
// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
// A is a row-major matrix with shape I x K.
// B is a column-major matrix with shape K x J.
// C is a column-major matrix with shape I x J.
// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements.
// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile.
// A is a row-major matrix with shape M x K.
// B is a column-major matrix with shape K x N.
// C is a column-major matrix with shape M x N.
// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
// Note that J is measured in physical 32 bit elements instead of logical elements.
// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
// All matrix tiles have ne physical 32 bit elements per warp.
//
// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
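Illustration only, not part of the patch: a minimal sketch of multiplying one pair of int8 fragments with the new tile API, using only the overloads defined in this file. A_sh and B_sh are assumed to point into shared memory (16-byte aligned) with strides given in units of int; the helper name is hypothetical.

// Hypothetical device helper combining load_ldmatrix and mma from ggml_cuda_mma.
using namespace ggml_cuda_mma;
static __device__ void example_int8_mma(const int * A_sh, const int * B_sh, const int stride) {
    tile<16, 8, int> A;  // row-major A fragment (int8 packed four-per-int)
    tile< 8, 8, int> B;  // column-major B fragment
    tile<16, 8, int> D;  // accumulator, zero-initialized by the tile struct
    load_ldmatrix(A, A_sh, stride);
    load_ldmatrix(B, B_sh, stride);
    mma(D, A, B);        // D += A * B via mma.sync (split into m8n8k16 ops on Turing)
}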
@@ -23,7 +24,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {

#ifdef NEW_MMA_AVAILABLE
    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
        : "+r"(ret) : "r"(x));
        : "=r"(ret) : "r"(x));
#else
    NO_DEVICE_CODE;
#endif // defined(NEW_MMA_AVAILABLE)
@@ -52,407 +53,342 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
|
||||
|
||||
#endif // CUDART_VERSION >= 11080
|
||||
|
||||
static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
|
||||
half2 ret;
|
||||
*((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct mma_A_I16K4 {
|
||||
static_assert(sizeof(T) == 4, "bad type size");
|
||||
namespace ggml_cuda_mma {
|
||||
|
||||
static constexpr int I = 16;
|
||||
static constexpr int K = 4;
|
||||
static constexpr int ne = 2;
|
||||
template <int I_, int J_, typename T>
|
||||
struct tile {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
T x[ne] = {0};
|
||||
|
||||
T x[ne];
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 8 && (J == 4 || J == 8)) {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return (l / 2) * 8 + threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = (l%2) * (I/2) + threadIdx.x / K;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < I);
|
||||
return ret;
|
||||
}
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 8 && J == 4) {
|
||||
return threadIdx.x % 4;
|
||||
} else if constexpr (I == 8 && J == 8) {
|
||||
return 4 * l + threadIdx.x % 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return 2 * (threadIdx.x % 4) + l % 2;
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static __device__ __forceinline__ int get_k(const int /* l */) {
|
||||
const int ret = threadIdx.x % K;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
template <int I_, int J_>
|
||||
struct tile<I_, J_, half2> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return l * 8 + threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return (l % 2) * 8 + threadIdx.x / 4;
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return l * 4 + threadIdx.x % 4;
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return threadIdx.x % 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return (l / 2) * 4 + threadIdx.x % 4;
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int I, int J>
|
||||
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
||||
tile<I, J/2, half2> ret;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
x[l] = xs0[get_i(l)*stride + get_k(l)];
|
||||
for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
|
||||
ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
|
||||
tile<8, 8, half2> ret;
|
||||
ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
|
||||
ret.x[1] = ggml_cuda_movmatrix(t.x[1]);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <int I, int J, typename T>
|
||||
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < t.ne; ++l) {
|
||||
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * xi = (int *) x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride;
|
||||
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||
: "+r"(xi[0]), "+r"(xi[1])
|
||||
int * xi = (int *) t.x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||
: "=r"(xi[0]), "=r"(xi[1])
|
||||
: "l"(xs));
|
||||
#else
|
||||
load_generic(t, xs0, stride);
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * xi = (int *) t.x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||
: "=r"(xi[0]), "=r"(xi[1])
|
||||
: "l"(xs));
|
||||
#else
|
||||
load_generic(xs0, stride);
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct mma_A_I16K8 {
|
||||
static_assert(sizeof(T) == 4, "bad type size");
|
||||
|
||||
static constexpr int I = 16;
|
||||
static constexpr int K = 8;
|
||||
static constexpr int ne = 4;
|
||||
|
||||
T x[ne];
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < I);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_k(const int l) {
|
||||
const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
x[l] = xs0[get_i(l)*stride + get_k(l)];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * xi = (int * ) x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
|
||||
asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
|
||||
int * xi = (int * ) t.x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
|
||||
: "l"(xs));
|
||||
#else
|
||||
load_generic(t, xs0, stride);
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ void load_ldmatrix_trans(
|
||||
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * xi = (int * ) t.x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
|
||||
: "l"(xs));
|
||||
#else
|
||||
GGML_UNUSED(t);
|
||||
GGML_UNUSED(xs0);
|
||||
GGML_UNUSED(stride);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * xi = (int * ) x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
|
||||
asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3])
|
||||
: "l"(xs));
|
||||
#else
|
||||
GGML_UNUSED(xs0);
|
||||
GGML_UNUSED(stride);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void transpose() {
|
||||
int * xi = (int *) x;
|
||||
xi[0] = ggml_cuda_movmatrix(xi[0]);
|
||||
|
||||
const int tmp = ggml_cuda_movmatrix(xi[1]);
|
||||
xi[1] = ggml_cuda_movmatrix(xi[2]);
|
||||
xi[2] = tmp;
|
||||
|
||||
xi[3] = ggml_cuda_movmatrix(xi[3]);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct mma_B_J8K4 {
|
||||
static_assert(sizeof(T) == 4, "bad type size");
|
||||
|
||||
static constexpr int J = 8;
|
||||
static constexpr int K = 4;
|
||||
static constexpr int ne = 1;
|
||||
|
||||
T x[ne];
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int /* l */) {
|
||||
const int ret = threadIdx.x / K;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < J);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_k(const int /* l */) {
|
||||
const int ret = threadIdx.x % K;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
x[l] = xs0[get_j(l)*stride + get_k(l)];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * xi = (int *) x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride;
|
||||
asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
|
||||
: "+r"(xi[0]) : "l"(xs));
|
||||
#else
|
||||
load_generic(xs0, stride);
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct mma_B_J8K8 {
|
||||
static_assert(sizeof(T) == 4, "bad type size");
|
||||
|
||||
static constexpr int J = 8;
|
||||
static constexpr int K = 8;
|
||||
static constexpr int ne = 2;
|
||||
|
||||
T x[ne];
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int /* l */) {
|
||||
const int ret = threadIdx.x / (K/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < J);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_k(const int l) {
|
||||
const int ret = l * (K/2) + threadIdx.x % (K/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
x[l] = xs0[get_j(l)*stride + get_k(l)];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * xi = (int *) x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
|
||||
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||
: "+r"(xi[0]), "+r"(xi[1])
|
||||
: "l"(xs));
|
||||
#else
|
||||
load_generic(xs0, stride);
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct mma_C_I16J8 {};
|
||||
|
||||
template <>
|
||||
struct mma_C_I16J8<int> {
|
||||
static constexpr int I = 16;
|
||||
static constexpr int J = 8;
|
||||
static constexpr int ne = 4;
|
||||
|
||||
int x[ne] = {0};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < I);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < J);
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) {
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
|
||||
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
|
||||
: "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[0]), "+r"(x[1])
|
||||
: "r"(mma_A.x[0]), "r"(mma_B.x[0]));
|
||||
: "+r"(D.x[0]), "+r"(D.x[1])
|
||||
: "r"(A.x[0]), "r"(B.x[0]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
|
||||
: "+r"(D.x[2]), "+r"(D.x[3])
|
||||
: "r"(A.x[1]), "r"(B.x[0]));
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(mma_A);
|
||||
GGML_UNUSED(mma_B);
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mma(const mma_A_I16K8<int> & mma_A, const mma_B_J8K8<int> & mma_B) {
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
||||
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
|
||||
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
|
||||
: "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
|
||||
#else
|
||||
// On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[0]), "+r"(x[1])
|
||||
: "r"(mma_A.x[0]), "r"(mma_B.x[0]));
|
||||
: "+r"(D.x[0]), "+r"(D.x[1])
|
||||
: "r"(A.x[0]), "r"(B.x[0]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
|
||||
: "+r"(D.x[2]), "+r"(D.x[3])
|
||||
: "r"(A.x[1]), "r"(B.x[0]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[0]), "+r"(x[1])
|
||||
: "r"(mma_A.x[2]), "r"(mma_B.x[1]));
|
||||
: "+r"(D.x[0]), "+r"(D.x[1])
|
||||
: "r"(A.x[2]), "r"(B.x[1]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[3]), "r"(mma_B.x[1]));
|
||||
: "+r"(D.x[2]), "+r"(D.x[3])
|
||||
: "r"(A.x[3]), "r"(B.x[1]));
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(mma_A);
|
||||
GGML_UNUSED(mma_B);
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mma_C_I16J8<half2> {
|
||||
static constexpr int I = 16;
|
||||
static constexpr int J = 4;
|
||||
static constexpr int ne = 2;
|
||||
|
||||
half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = l * (I/2) + threadIdx.x / J;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < I);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int /* l */) {
|
||||
const int ret = threadIdx.x % J;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < J);
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * Axi = (int *) mma_A.x;
|
||||
int * Bxi = (int *) mma_B.x;
|
||||
int * xi = (int *) x;
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
int * Dxi = (int *) D.x;
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
|
||||
: "+r"(xi[0]), "+r"(xi[1])
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(xi[0]), "+r"(xi[1])
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(xi[0]), "+r"(xi[1])
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(mma_A);
|
||||
GGML_UNUSED(mma_B);
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
__device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
|
||||
mma_B_J8K8<half2> mma_B;
|
||||
|
||||
int * xi = (int *) x;
|
||||
int * Bxi = (int *) mma_B.x;
|
||||
Bxi[0] = ggml_cuda_movmatrix(xi[0]);
|
||||
Bxi[1] = ggml_cuda_movmatrix(xi[1]);
|
||||
|
||||
return mma_B;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mma_C_I16J8<float> {
|
||||
static constexpr int I = 16;
|
||||
static constexpr int J = 8;
|
||||
static constexpr int ne = 4;
|
||||
|
||||
float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < I);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < J);
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * Axi = (int *) mma_A.x;
|
||||
int * Bxi = (int *) mma_B.x;
|
||||
int * xi = (int *) x;
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
int * Dxi = (int *) D.x;
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
|
||||
: "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
int * Dxi = (int *) D.x;
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
||||
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(mma_A);
|
||||
GGML_UNUSED(mma_B);
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
__device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
|
||||
mma_B_J8K8<half2> mma_B;
|
||||
mma_B.x[0] = make_half2(x[0], x[1]);
|
||||
mma_B.x[1] = make_half2(x[2], x[3]);
|
||||
|
||||
int * Bxi = (int *) mma_B.x;
|
||||
Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
|
||||
Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
|
||||
|
||||
return mma_B;
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
int * Dxi = (int *) D.x;
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
x[l] = xs0[get_j(l)*stride + get_i(l)];
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
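The hunks above retire the shape-specific mma_A_* / mma_B_* / mma_C_* structs in favour of a single templated tile type plus free functions (load_generic, load_ldmatrix, mma). As a rough sketch of that interface only — assuming a 32-thread warp, reusing the 16x8 int accumulator index mapping visible in the diff, and not the actual ggml header:

namespace mma_sketch {

template <int I_, int J_, typename T>
struct tile {
    static constexpr int I  = I_;
    static constexpr int J  = J_;
    static constexpr int ne = I_ * J_ / 32; // elements per thread; matches the ne values above (16x8 -> 4, 8x8 -> 2, 8x4 -> 1)

    T x[ne] = {};

    // Thread -> element mapping copied from the 16x8 int accumulator tile above;
    // other shapes in the real header use different mappings.
    static __device__ __forceinline__ int get_i(const int l) {
        return (l / 2) * (I / 2) + threadIdx.x / (J / 2);
    }
    static __device__ __forceinline__ int get_j(const int l) {
        return 2 * (threadIdx.x % (J / 2)) + l % 2;
    }
};

// Generic fallback load: each thread gathers its own elements from a row-major buffer,
// mirroring the load_generic(...) free function used by the MMQ kernels below.
template <int I, int J, typename T>
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
#pragma unroll
    for (int l = 0; l < tile<I, J, T>::ne; ++l) {
        t.x[l] = xs0[tile<I, J, T>::get_i(l)*stride + tile<I, J, T>::get_j(l)];
    }
}

} // namespace mma_sketch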
@@ -7,6 +7,8 @@
|
||||
#include <climits>
|
||||
#include <cstdint>
|
||||
|
||||
using namespace ggml_cuda_mma;
|
||||
|
||||
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
|
||||
#define MMQ_ITER_K 256
|
||||
#define MMQ_NWARPS 8
|
||||
@@ -647,15 +649,15 @@ template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
|
||||
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
||||
|
||||
typedef mma_A_I16K8<int> mma_A;
|
||||
typedef mma_B_J8K8<int> mma_B;
|
||||
typedef mma_C_I16J8<int> mma_C;
|
||||
typedef tile<16, 8, int> tile_A;
|
||||
typedef tile< 8, 8, int> tile_B;
|
||||
typedef tile<16, 8, int> tile_C;
|
||||
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
constexpr int rows_per_warp = 2 * granularity;
|
||||
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
|
||||
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
||||
|
||||
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
|
||||
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
|
||||
|
||||
const int * x_qs = (const int *) x;
|
||||
const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
|
||||
@@ -663,8 +665,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
||||
const float * y_df = (const float *) y;
|
||||
const half2 * y_ds = (const half2 *) y;
|
||||
|
||||
mma_A A[ntx][WARP_SIZE/QI8_0];
|
||||
float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
|
||||
tile_A A[ntx][WARP_SIZE/QI8_0];
|
||||
float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
|
||||
|
||||
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
|
||||
|
||||
@@ -674,12 +676,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
|
||||
const int k0 = k00 + k01;
|
||||
|
||||
A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
|
||||
load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
|
||||
for (int l = 0; l < tile_C::ne/2; ++l) {
|
||||
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
|
||||
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
|
||||
@@ -691,17 +693,17 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
|
||||
mma_B B;
|
||||
float dB[mma_C::ne/2];
|
||||
tile_B B;
|
||||
float dB[tile_C::ne/2];
|
||||
|
||||
B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
|
||||
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
for (int l = 0; l < tile_C::ne/2; ++l) {
|
||||
const int j = j0 + tile_C::get_j(l);
|
||||
|
||||
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
|
||||
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
||||
@@ -712,12 +714,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
mma_C C;
|
||||
C.mma(A[n][k01/QI8_0], B);
|
||||
tile_C C;
|
||||
mma(C, A[n][k01/QI8_0], B);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -758,23 +760,23 @@ template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
||||
|
||||
typedef mma_A_I16K8<int> mma_A;
|
||||
typedef mma_B_J8K8<int> mma_B;
|
||||
typedef mma_C_I16J8<int> mma_C;
|
||||
typedef tile<16, 8, int> tile_A;
|
||||
typedef tile< 8, 8, int> tile_B;
|
||||
typedef tile<16, 8, int> tile_C;
|
||||
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
constexpr int rows_per_warp = 2 * granularity;
|
||||
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
|
||||
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
||||
|
||||
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
|
||||
y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
|
||||
|
||||
const int * x_qs = (const int *) x;
|
||||
const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const half2 * y_dm = (const half2 *) y;
|
||||
|
||||
mma_A A[ntx][WARP_SIZE/QI8_1];
|
||||
float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
|
||||
tile_A A[ntx][WARP_SIZE/QI8_1];
|
||||
float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
|
||||
|
||||
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
|
||||
|
||||
@@ -784,12 +786,12 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
|
||||
const int k0 = k00 + k01;
|
||||
|
||||
A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
|
||||
load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
|
||||
for (int l = 0; l < tile_C::ne/2; ++l) {
|
||||
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
|
||||
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
|
||||
@@ -801,30 +803,30 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
|
||||
mma_B B;
|
||||
float2 dsB[mma_C::ne/2];
|
||||
tile_B B;
|
||||
float2 dsB[tile_C::ne/2];
|
||||
|
||||
B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
|
||||
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
for (int l = 0; l < tile_C::ne/2; ++l) {
|
||||
const int j = j0 + tile_C::get_j(l);
|
||||
|
||||
dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
mma_C C;
|
||||
C.mma(A[n][k01/QI8_1], B);
|
||||
tile_C C;
|
||||
mma(C, A[n][k01/QI8_1], B);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
|
||||
sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
|
||||
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -868,26 +870,26 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
|
||||
typedef mma_A_I16K4<int> mma_A;
|
||||
typedef mma_A_I16K8<int> mma_A_K8;
|
||||
typedef mma_B_J8K4<int> mma_B;
|
||||
typedef mma_C_I16J8<int> mma_C;
|
||||
typedef tile<16, 4, int> tile_A;
|
||||
typedef tile<16, 8, int> tile_A_8;
|
||||
typedef tile< 8, 4, int> tile_B;
|
||||
typedef tile<16, 8, int> tile_C;
|
||||
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
constexpr int rows_per_warp = 2 * granularity;
|
||||
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
|
||||
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
||||
|
||||
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
|
||||
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
|
||||
|
||||
const int * x_qs = (const int *) x;
|
||||
const float * x_df = (const float *) x_qs + WARP_SIZE*2;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const float * y_df = (const float *) y;
|
||||
|
||||
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
|
||||
const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
|
||||
|
||||
mma_A A[ntx][8];
|
||||
float dA[ntx][mma_C::ne/2][8];
|
||||
tile_A A[ntx][8];
|
||||
float dA[ntx][tile_C::ne/2][8];
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
@@ -895,12 +897,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
|
||||
const int k0 = k00 + k01;
|
||||
|
||||
((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
||||
load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
|
||||
for (int l = 0; l < tile_C::ne/2; ++l) {
|
||||
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
|
||||
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
|
||||
@@ -912,32 +914,32 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
|
||||
mma_B B[2];
|
||||
float dB[mma_C::ne/2];
|
||||
tile_B B[2];
|
||||
float dB[tile_C::ne/2];
|
||||
|
||||
// Here load_generic is faster than load_ldmatrix.
|
||||
B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
|
||||
B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
|
||||
load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
|
||||
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
for (int l = 0; l < tile_C::ne/2; ++l) {
|
||||
const int j = j0 + tile_C::get_j(l);
|
||||
|
||||
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
mma_C C[2];
|
||||
C[0].mma(A[n][k01/4 + 0], B[0]);
|
||||
C[1].mma(A[n][k01/4 + 1], B[1]);
|
||||
tile_C C[2];
|
||||
mma(C[0], A[n][k01/4 + 0], B[0]);
|
||||
mma(C[1], A[n][k01/4 + 1], B[1]);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1056,27 +1058,27 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
|
||||
typedef mma_A_I16K4<int> mma_A;
|
||||
typedef mma_A_I16K8<int> mma_A_K8;
|
||||
typedef mma_B_J8K4<int> mma_B;
|
||||
typedef mma_C_I16J8<int> mma_C;
|
||||
typedef tile<16, 4, int> tile_A;
|
||||
typedef tile<16, 8, int> tile_A_8;
|
||||
typedef tile< 8, 4, int> tile_B;
|
||||
typedef tile<16, 8, int> tile_C;
|
||||
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
constexpr int rows_per_warp = 2 * granularity;
|
||||
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
|
||||
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
||||
|
||||
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
|
||||
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
|
||||
|
||||
const int * x_qs = (const int *) x;
|
||||
const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const half2 * y_ds = (const half2 *) y;
|
||||
|
||||
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
|
||||
const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
|
||||
|
||||
mma_A A[ntx][8];
|
||||
float dA[ntx][mma_C::ne/2][8];
|
||||
float mA[ntx][mma_C::ne/2][8];
|
||||
tile_A A[ntx][8];
|
||||
float dA[ntx][tile_C::ne/2][8];
|
||||
float mA[ntx][tile_C::ne/2][8];
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
@@ -1084,15 +1086,15 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
|
||||
const int k0 = k00 + k01;
|
||||
|
||||
((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
||||
load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
|
||||
for (int l = 0; l < tile_C::ne/2; ++l) {
|
||||
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
|
||||
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
|
||||
@@ -1107,58 +1109,58 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
|
||||
float2 dB[mma_C::ne/2];
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
||||
float2 dB[tile_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
for (int l = 0; l < tile_C::ne/2; ++l) {
|
||||
const int j = j0 + tile_C::get_j(l);
|
||||
|
||||
dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
|
||||
mma_B B[2];
|
||||
tile_B B[2];
|
||||
|
||||
// Here load_generic is faster than load_ldmatrix.
|
||||
B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
|
||||
B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
|
||||
load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
|
||||
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
|
||||
|
||||
mma_C Cm[2];
|
||||
tile_C Cm[2];
|
||||
if (k01 >= WARP_SIZE * 3/4) {
|
||||
mma_A A1;
|
||||
tile_A A1;
|
||||
A1.x[0] = 0x01010101;
|
||||
A1.x[1] = 0x01010101;
|
||||
Cm[0].mma(A1, B[0]);
|
||||
Cm[1].mma(A1, B[1]);
|
||||
mma(Cm[0], A1, B[0]);
|
||||
mma(Cm[1], A1, B[1]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
mma_C Cd[2];
|
||||
tile_C Cd[2];
|
||||
|
||||
Cd[0].mma(A[n][k01/4 + 0], B[0]);
|
||||
Cd[1].mma(A[n][k01/4 + 1], B[1]);
|
||||
mma(Cd[0], A[n][k01/4 + 0], B[0]);
|
||||
mma(Cd[1], A[n][k01/4 + 1], B[1]);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
|
||||
if (k01 >= WARP_SIZE * 3/4) {
|
||||
tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
|
||||
}
|
||||
sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
|
||||
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
|
||||
float2 sB[mma_C::ne/2];
|
||||
float2 sB[tile_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
for (int l = 0; l < tile_C::ne/2; ++l) {
|
||||
const int j = j0 + tile_C::get_j(l);
|
||||
|
||||
sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
|
||||
}
|
||||
@@ -1166,9 +1168,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
|
||||
sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
|
||||
sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1708,15 +1710,15 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
|
||||
typedef mma_A_I16K4<int> mma_A;
|
||||
typedef mma_B_J8K4<int> mma_B;
|
||||
typedef mma_C_I16J8<int> mma_C;
|
||||
typedef tile<16, 4, int> tile_A;
|
||||
typedef tile< 8, 4, int> tile_B;
|
||||
typedef tile<16, 8, int> tile_C;
|
||||
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
constexpr int rows_per_warp = 2 * granularity;
|
||||
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
|
||||
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
||||
|
||||
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
|
||||
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
|
||||
|
||||
const int * x_qs = (const int *) x;
|
||||
const float * x_df = (const float *) x_qs + WARP_SIZE*2;
|
||||
@@ -1724,11 +1726,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const float * y_df = (const float *) y;
|
||||
|
||||
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
|
||||
const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
|
||||
|
||||
mma_A A[ntx][8];
|
||||
int scA[ntx][mma_C::ne/2][8];
|
||||
float dA[ntx][mma_C::ne/2];
|
||||
tile_A A[ntx][8];
|
||||
int scA[ntx][tile_C::ne/2][8];
|
||||
float dA[ntx][tile_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
@@ -1736,8 +1738,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
|
||||
const int k0 = k00 + k01;
|
||||
|
||||
A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
|
||||
A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
|
||||
load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
|
||||
load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@@ -1745,8 +1747,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
||||
const int k0 = k00 + k01;
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
|
||||
for (int l = 0; l < tile_C::ne/2; ++l) {
|
||||
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
|
||||
|
||||
const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
|
||||
const int8_t * sc = (const int8_t *) &sc_packed;
|
||||
@@ -1759,41 +1761,41 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
|
||||
for (int l = 0; l < tile_C::ne/2; ++l) {
|
||||
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
|
||||
|
||||
dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
|
||||
float tmp[ntx][mma_C::ne] = {{0.0f}};
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
||||
float tmp[ntx][tile_C::ne] = {{0.0f}};
|
||||
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
|
||||
mma_B B[2];
|
||||
float dB[mma_C::ne/2];
|
||||
tile_B B[2];
|
||||
float dB[tile_C::ne/2];
|
||||
|
||||
// Here load_generic is faster than load_ldmatrix.
|
||||
B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
|
||||
B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
|
||||
load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
|
||||
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
for (int l = 0; l < tile_C::ne/2; ++l) {
|
||||
const int j = j0 + tile_C::get_j(l);
|
||||
|
||||
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
mma_C C[2];
|
||||
C[0].mma(A[n][k01/4 + 0], B[0]);
|
||||
C[1].mma(A[n][k01/4 + 1], B[1]);
|
||||
tile_C C[2];
|
||||
mma(C[0], A[n][k01/4 + 0], B[0]);
|
||||
mma(C[1], A[n][k01/4 + 1], B[1]);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
|
||||
}
|
||||
}
|
||||
@@ -1802,8 +1804,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2312,36 +2314,36 @@ template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
static __device__ __forceinline__ void mmq_write_back_mma(
|
||||
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
|
||||
|
||||
typedef mma_C_I16J8<int> mma_C;
|
||||
typedef tile<16, 8, int> tile_C;
|
||||
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
constexpr int rows_per_warp = 2 * granularity;
|
||||
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
|
||||
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
||||
|
||||
const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
|
||||
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
|
||||
static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);
|
||||
|
||||
if (j > j_max) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int i = i0 + n*mma_C::I + mma_C::get_i(l);
|
||||
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
||||
|
||||
if (need_check && i > i_max) {
|
||||
continue;
|
||||
}
|
||||
|
||||
dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
|
||||
dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
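Summarizing the call-site pattern that repeats through the vec_dot_*_mma and mmq_write_back_mma hunks above: member functions on per-shape structs become free functions over one templated tile type. The snippet below is a self-contained schematic with empty stand-in types and placeholder pointer names; it is not code from the repository.

// Stand-ins so the schematic compiles on its own; in the real code these come from mma.cuh.
struct old_tile_A { int x[4];       __device__ void load_ldmatrix(const int *, int) {} };
struct old_tile_B { int x[2];       __device__ void load_generic (const int *, int) {} };
struct old_tile_C { int x[4] = {0}; __device__ void mma(const old_tile_A &, const old_tile_B &) {} };

template <int I, int J, typename T> struct tile { T x[I*J/32] = {}; };
template <int I, int J, typename T> static __device__ void load_ldmatrix(tile<I, J, T> &, const T *, int) {}
template <int I, int J, typename T> static __device__ void load_generic (tile<I, J, T> &, const T *, int) {}
static __device__ void mma(tile<16, 8, int> &, const tile<16, 8, int> &, const tile<8, 8, int> &) {}

__device__ void call_site_before(const int * x_qs, const int * y_qs, int stride) {
    old_tile_A A; old_tile_B B; old_tile_C C;
    A.load_ldmatrix(x_qs, stride); // member functions on shape-specific structs
    B.load_generic (y_qs, stride);
    C.mma(A, B);
}

__device__ void call_site_after(const int * x_qs, const int * y_qs, int stride) {
    tile<16, 8, int> A; tile<8, 8, int> B; tile<16, 8, int> C;
    load_ldmatrix(A, x_qs, stride); // free functions over one templated tile type
    load_generic (B, y_qs, stride);
    mma(C, A, B);
}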
@@ -1,10 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16);
|
||||
@@ -1,10 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 32);
|
||||
@@ -1,10 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 64);
|
||||
@@ -1,10 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 1, 8);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 1);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 2);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 4);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 2, 4);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 2, 8);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 32, 1);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 32, 2);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 2);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 4);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 8);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 64, 1);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 1);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 2);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 4);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 8);
|
||||
@@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f

"""

SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n"
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"

TYPES_MMQ = [
    "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@@ -57,12 +57,18 @@ for vkq_size in [16, 32]:
        with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
            f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))

for cols_per_block in [8, 16, 32, 64]:
    with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f:
        f.write(SOURCE_FATTN_MMA_START)
for ncols in [8, 16, 32, 64, 128]:
    for ncols2 in [1, 2, 4, 8]:
        ncols1 = ncols // ncols2
        if ncols == 128:
            continue # Too much register pressure.
        with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
            f.write(SOURCE_FATTN_MMA_START)

        for head_size in [64, 80, 96, 112, 128, 256]:
            f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size))
            for head_size in [64, 80, 96, 112, 128, 256]:
                if ncols == 128 and head_size == 256:
                    continue # Needs too much shared memory.
                f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))

for type in TYPES_MMQ:
    with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
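For reference, the (ncols1, ncols2) pairs produced by the updated loop can be enumerated with a few lines of host code. This is only a cross-check of the generator logic above; its output matches the sixteen fattn-mma-f16-instance-ncols1_*-ncols2_*.cu files added earlier in this diff.

#include <cstdio>

int main() {
    const int ncols_list[]  = {8, 16, 32, 64, 128};
    const int ncols2_list[] = {1, 2, 4, 8};
    for (int ncols : ncols_list) {
        for (int ncols2 : ncols2_list) {
            const int ncols1 = ncols / ncols2;
            if (ncols == 128) {
                continue; // Too much register pressure (same guard as the generator).
            }
            std::printf("fattn-mma-f16-instance-ncols1_%d-ncols2_%d.cu\n", ncols1, ncols2);
        }
    }
    return 0;
}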
@@ -107,6 +107,10 @@ if (GGML_HIP_NO_VMM)
    add_compile_definitions(GGML_HIP_NO_VMM)
endif()

if (NOT GGML_CUDA_FA)
    add_compile_definitions(GGML_CUDA_NO_FA)
endif()

if (CXX_IS_HIPCC)
    set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
    target_link_libraries(ggml-hip PRIVATE hip::device)
@@ -16,7 +16,7 @@
#include <arm_sve.h>
#endif // __ARM_FEATURE_SVE

#if defined(__ARM_NEON) && !defined(__CUDACC__)
#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
//
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
@@ -24,7 +24,7 @@
#endif

// create residency sets only on macOS >= 15.0
#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
    TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \
    TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \
    TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000
@@ -49,7 +49,7 @@ if (MUSAToolkit_FOUND)

    set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)
    foreach(SOURCE ${GGML_SOURCES_MUSA})
        set(COMPILE_FLAGS "-x musa -mtgpu")
        set(COMPILE_FLAGS "-fsigned-char -x musa -mtgpu")
        foreach(ARCH ${MUSA_ARCHITECTURES})
            set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}")
        endforeach()
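The -fsigned-char addition forces plain char to be signed when compiling the MUSA sources, presumably because the default signedness there differs from the usual x86 behaviour. The signedness of plain char is implementation-defined, and quantized data is routinely reinterpreted through char-based pointers, so the difference is observable; the tiny host example below only illustrates the flag's effect and is not taken from the repository.

#include <cstdio>

int main() {
    unsigned char raw = 0xF0;        // a packed quant byte
    char          c   = (char) raw;  // signed or unsigned depending on target/flags
    std::printf("%d\n", (int) c);    // -16 with -fsigned-char, 240 if plain char is unsigned
    return 0;
}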
@@ -83,6 +83,10 @@ if (MUSAToolkit_FOUND)
        add_compile_definitions(GGML_CUDA_NO_VMM)
    endif()

    if (NOT GGML_CUDA_FA)
        add_compile_definitions(GGML_CUDA_NO_FA)
    endif()

    if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
        add_compile_definitions(GGML_CUDA_F16)
    endif()
@@ -1,3 +1,5 @@
message(STATUS "GGML_SYCL_TARGET=${GGML_SYCL_TARGET}")

if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$")
    message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD")
endif()
@@ -99,3 +99,20 @@ catch (sycl::exception const &exc) {
              << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}


void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams) {
    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
        for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
            if (extra->events[i][is] != nullptr) {
                SYCL_CHECK(CHECK_TRY_ERROR(dpct::destroy_event(extra->events[i][is])));
            }
        }
        if (extra->data_device[i] != nullptr && streams.size()>0) {
            ggml_sycl_set_device(i);
            SYCL_CHECK(
                CHECK_TRY_ERROR(sycl::free(extra->data_device[i], *(streams[i]))));
        }
    }
    delete extra;
}
@@ -19,6 +19,9 @@
#include "dpct/helper.hpp"
#include "ggml-sycl.h"
#include "presets.hpp"
#include "sycl_hw.hpp"


#if GGML_SYCL_DNNL
#include "dnnl.hpp"
#include "dnnl_sycl.hpp"
@@ -35,7 +38,10 @@
void* ggml_sycl_host_malloc(size_t size);
void ggml_sycl_host_free(void* ptr);

static int g_ggml_sycl_debug = 0;

extern int g_ggml_sycl_debug;
extern int g_ggml_sycl_disable_optimize;

#define GGML_SYCL_DEBUG(...) \
    do { \
        if (g_ggml_sycl_debug) \
@@ -182,18 +188,24 @@ inline dpct::err0 ggml_sycl_set_device(const int device) try {
}

//////////////////////
struct optimize_feature {
    bool reorder=false;
};

struct sycl_device_info {
    int cc; // compute capability
    // int nsm; // number of streaming multiprocessors
    // size_t smpb; // max. shared memory per block
    bool vmm; // virtual memory support
    size_t total_vram;
    sycl_hw_info hw_info;
    optimize_feature opt_feature;
};


struct ggml_sycl_device_info {
    int device_count;

    struct sycl_device_info {
        int cc; // compute capability
        // int nsm; // number of streaming multiprocessors
        // size_t smpb; // max. shared memory per block
        bool vmm; // virtual memory support
        size_t total_vram;
    };

    sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {};

    std::array<float, GGML_SYCL_MAX_DEVICES> default_tensor_split = {};
@@ -260,17 +272,46 @@ struct ggml_tensor_extra_gpu {
    // tensors
    dpct::event_ptr events[GGML_SYCL_MAX_DEVICES]
                          [GGML_SYCL_MAX_STREAMS]; // events for synchronizing multiple GPUs
    optimize_feature optimized_feature;
};

void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams={});

inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) {
    optimize_feature opt;

    opt.reorder =
        (arch == syclex::architecture::intel_gpu_dg1 ||
         arch == syclex::architecture::intel_gpu_acm_g10 ||
         arch == syclex::architecture::intel_gpu_acm_g11 ||
         arch == syclex::architecture::intel_gpu_acm_g12 ||
         arch == syclex::architecture::intel_gpu_pvc ||
         arch == syclex::architecture::intel_gpu_pvc_vg ||
         arch == syclex::architecture::intel_gpu_mtl_u ||
         arch == syclex::architecture::intel_gpu_mtl_s ||
         arch == syclex::architecture::intel_gpu_mtl_h ||
         arch == syclex::architecture::intel_gpu_arl_u ||
         arch == syclex::architecture::intel_gpu_arl_s ||
         arch == syclex::architecture::intel_gpu_arl_h ||
         arch == syclex::architecture::intel_gpu_bmg_g21 ||
         arch == syclex::architecture::intel_gpu_lnl_m
        );

    return opt;
}

struct ggml_backend_sycl_context {
    int device;
    std::string name;
    optimize_feature opt_feature;
    bool optimized_graph=false;

    queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } };

    explicit ggml_backend_sycl_context(int device) :
        device(device),
        name(GGML_SYCL_NAME + std::to_string(device)) {
        opt_feature = ggml_sycl_info().devices[device].opt_feature;
    }

    queue_ptr stream(int device, int stream) {
@@ -680,5 +721,4 @@ bool gpu_has_xmx(sycl::device &dev);
void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                          const ggml_tensor *src1, ggml_tensor *dst,
                          const ggml_sycl_op_flatten_t op);

#endif // GGML_SYCL_COMMON_HPP
@@ -125,6 +125,25 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
    }
}

template <typename dst_t>
static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k,
                                             dpct::queue_ptr stream) {

    dpct::has_capability_or_fail(stream->get_device(),
                                 {sycl::aspect::fp16});

    int constexpr WARP_K = WARP_SIZE * QK4_0;
    const int n_warp = (k + WARP_K - 1) / WARP_K;
    GGML_ASSERT(k % 2 == 0);
    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
                                               sycl::range<3>(1, 1, WARP_SIZE),
                                           sycl::range<3>(1, 1, WARP_SIZE)),
                         [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
                             dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
                         });

}

template <typename dst_t>
static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
                                     dpct::queue_ptr stream) {
@@ -452,10 +471,15 @@ static void convert_unary_sycl(const void *__restrict__ vx,
    }
}

to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) {
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) {
    switch (type) {
        case GGML_TYPE_Q4_0:
            return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
            if (dst->src[0]->extra &&
                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
                return dequantize_row_q4_0_sycl_reorder;
            } else {
                return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
            }
        case GGML_TYPE_Q4_1:
            return dequantize_block_sycl<QK4_1, QR4_1, dequantize_q4_1>;
        case GGML_TYPE_Q5_0:
@@ -499,10 +523,15 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) {
    }
}

to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
    switch (type) {
        case GGML_TYPE_Q4_0:
            return dequantize_row_q4_0_sycl;
            if (dst->src[0]->extra &&
                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
                return dequantize_row_q4_0_sycl_reorder;
            } else {
                return dequantize_row_q4_0_sycl;
            }
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_sycl;
        case GGML_TYPE_Q5_0:

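Note: the launch shape of dequantize_row_q4_0_sycl_reorder above is easiest to see with concrete numbers — a sketch assuming QK4_0 == 32 and WARP_SIZE == 32 (the usual values defined elsewhere in the backend):

    // each sub-group of WARP_SIZE threads covers WARP_K = 32 * 32 = 1024 weights, one QK4_0 block per thread,
    // so k = 4096 weights launch n_warp = (4096 + 1023) / 1024 = 4 work-groups of 32 threads
    constexpr int WARP_K_example = 32 * 32;
    constexpr int n_warp_example = (4096 + WARP_K_example - 1) / WARP_K_example;  // == 4
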
@@ -21,7 +21,7 @@ using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
typedef to_t_sycl_t<float> to_fp32_sycl_t;
typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;

to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type);
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type);
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst);
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst);

#endif // GGML_SYCL_CONVERT_HPP

@@ -16,6 +16,8 @@
#include "common.hpp"

typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs,
                                            const int iqs, dfloat2 &v);

static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
                                            const int iqs, dfloat2 &v) {
@@ -40,6 +42,29 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
#endif // GGML_SYCL_F16
}

static __dpct_inline__ void dequantize_q4_0_reorder(const void *d_ptr, const int64_t ib, const void *qs,
                                                    const int iqs, dfloat2 &v) {
    // const block_q4_0 * x = (const block_q4_0 *) vx;

    const dfloat d = (const dfloat)*((const sycl::half*)d_ptr+ib);

    const int vui = *((const uint8_t *)qs+iqs);

    v.x() = vui & 0xF;
    v.y() = vui >> 4;

#ifdef GGML_SYCL_F16
    // v = v - {8.0f, 8.0f};
    // v = v * {d, d};
    v.s0() = (v.s0() - 8.0f) * d;
    v.s1() = (v.s1() - 8.0f) * d;

#else
    v.x() = (v.x() - 8.0f) * d;
    v.y() = (v.y() - 8.0f) * d;
#endif // GGML_SYCL_F16
}

static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
                                            const int iqs, dfloat2 &v) {
    const block_q4_1 * x = (const block_q4_1 *) vx;
@@ -167,6 +192,36 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri
    }
}

template<typename dst_t>
static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
                                          const sycl::nd_item<3> &item_ct1) {

    const int64_t i = item_ct1.get_group(2);
    auto k=nb32;
    // assume 32 threads
    const int64_t tid = item_ct1.get_local_id(2);
    const int lane_ib = i * WARP_SIZE + tid;

    if (lane_ib >= k / QK4_0) {
        return;
    }

    dst_t * y_ptr = yy + lane_ib * QK4_0;

    auto qs = (const uint8_t*)vx + lane_ib * QK4_0 / 2;
    auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k / 2) + lane_ib;

    const float d = float(*s_ptr);

#pragma unroll
    for (int l = 0; l < QK4_0 / 2; ++l) {
        int vq = qs[l];
        y_ptr[l + 0] = d * ((vq & 0xF) - 8);
        y_ptr[l + 16] = d * ((vq >> 4) - 8);
    }

}

template<typename dst_t>
static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
                                  const sycl::nd_item<3> &item_ct1) {

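Note: the reorder kernels above all assume the same repacked Q4_0 layout — the packed 4-bit values for the whole tensor come first (k/2 bytes for k weights), followed by one sycl::half scale per QK4_0 block. A sketch of that addressing with hypothetical helpers (not part of this diff):

    // offsets match dequantize_block_q4_0_reorder: nibbles at byte 0, scales after k/2 bytes
    static inline const uint8_t * q4_0_reorder_qs(const void * vx) {
        return (const uint8_t *) vx;
    }
    static inline const sycl::half * q4_0_reorder_d(const void * vx, int64_t k) {
        return (const sycl::half *) ((const uint8_t *) vx + k / 2);
    }
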
@@ -3,7 +3,6 @@
#include "dequantize.hpp"
#include "presets.hpp"


static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
    const sycl::half *x = (const sycl::half *)vx;

@@ -91,6 +90,112 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
    }
}

template <int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_reorder>
static void dequantize_mul_mat_vec_reorder(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
                                           const sycl::nd_item<3> &item_ct1) {
    // qk = quantized weights per x block
    // qr = number of quantized weights per data value in x block
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);

    if (row >= nrows) {
        return;
    }

    const int tid = item_ct1.get_local_id(2);


    const int ncols_left = ncols % (QK4_0*WARP_SIZE);
    const int ncols_align = ncols - ncols_left;
    const int iter_stride = 8*2*GGML_SYCL_DMMV_X;
    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter //64/16=4, 512/16/2= 16
    const int y_offset = qr == 1 ? 1 : qk/2;

    // partial sum for each thread
#ifdef GGML_SYCL_F16
    sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
#else
    float tmp = 0.0f;
#endif // GGML_SYCL_F16
    const char *d_ptr = (const char*)vx+ncols*nrows/2;
    int i=0;
    for (i = 0; i < ncols_align; i += iter_stride) {
        const int col = i + vals_per_iter*tid;
        const int ib = (row*ncols + col)/qk; // x block index
        const int iqs = (col%qk)/qr; // x quant index
        const int iybs = col - col%qk; // y block start index

        // processing >2 values per i iter is faster for fast GPUs
#pragma unroll
        for (int j = 0; j < vals_per_iter; j += 2) {
            // process 2 vals per j iter

            // dequantize
            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
            dfloat2 v;
            dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);

            // matrix multiplication
            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
#ifdef GGML_SYCL_F16
            dfloat2 t1{y[iybs + iqs + j / qr + 0],
                       y[iybs + iqs + j / qr + y_offset]};

            tmp += v * t1;
#else
            tmp += v.x() * y[iybs + iqs + j / qr + 0];
            tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
#endif // GGML_SYCL_F16
        }
    }

    for (; i < ncols; i += iter_stride) {
        if (tid>=ncols_left/QK4_0) continue;
        const int col = i + vals_per_iter*tid;
        const int ib = (row*ncols + col)/qk; // x block index
        const int iqs = (col%qk)/qr; // x quant index
        const int iybs = col - col%qk; // y block start index

        // processing >2 values per i iter is faster for fast GPUs
#pragma unroll
        for (int j = 0; j < vals_per_iter; j += 2) {
            // process 2 vals per j iter

            // dequantize
            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
            dfloat2 v;
            dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);

            // matrix multiplication
            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
#ifdef GGML_SYCL_F16
            dfloat2 t1{y[iybs + iqs + j / qr + 0],
                       y[iybs + iqs + j / qr + y_offset]};

            tmp += v * t1;
#else
            tmp += v.x() * y[iybs + iqs + j / qr + 0];
            tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
#endif // GGML_SYCL_F16
        }
    }

    // sum up partial sums and write back result
    const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
    for (int mask = mask_start; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

    if (tid == 0) {
#ifdef GGML_SYCL_F16
        dst[row] = tmp.x() + tmp.y();
#else
        dst[row] = tmp;
#endif // GGML_SYCL_F16
    }
}

static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
                                         float *dst, const int ncols,
                                         const int nrows,
@@ -759,6 +864,28 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
    }
}

static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y,
                                                      float *dst, const int ncols,
                                                      const int nrows,
                                                      dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->parallel_for(
            sycl::nd_range<3>(block_nums * block_dims, block_dims),
            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
                    vx, y, dst, ncols, nrows, item_ct1);
            });
    }
}


static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
                                              float *dst, const int ncols,
@@ -953,7 +1080,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(

    const int64_t ne00 = src0->ne[0];
    const int64_t row_diff = row_high - row_low;

    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
#ifdef GGML_SYCL_F16
@@ -967,7 +1093,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec(

    if (src1_convert_f16) {
        src1_dfloat = src1_dfloat_a.alloc(ne00);
        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
        GGML_ASSERT(to_fp16_sycl != nullptr);
        to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream);
    }
@@ -977,7 +1103,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec(

    switch (src0->type) {
        case GGML_TYPE_Q4_0:
            dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
            if ((ggml_tensor_extra_gpu*)dst->src[0]->extra &&
                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
                dequantize_mul_mat_vec_q4_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
            } else {
                dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
            }
            break;
        case GGML_TYPE_Q4_1:
            dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
@@ -1020,4 +1151,5 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
    GGML_UNUSED(src1_ddq_i);
    GGML_UNUSED(src1_ncols);
    GGML_UNUSED(src1_padded_row_size);
    GGML_UNUSED(ctx);
}

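Note: a quick sanity check of the per-thread split in dequantize_mul_mat_vec_reorder above, assuming the usual presets GGML_SYCL_DMMV_X == 32 and WARP_SIZE == 32 (see presets.hpp for the actual values):

    constexpr int iter_stride_example   = 8 * 2 * 32;                // 512 columns per sub-group per i-iteration
    constexpr int vals_per_iter_example = iter_stride_example / 32;  // 16 quantized values per thread
    static_assert(vals_per_iter_example == 16, "each thread dequantizes 16 values per i-iteration");
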
308
ggml/src/ggml-sycl/getrows.cpp
Normal file
@@ -0,0 +1,308 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//

//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//

#include "ggml-impl.h"
#include "common.hpp"
#include "dequantize.hpp"
#include "getrows.hpp"


template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void k_get_rows(
        const void * src0, const int32_t * src1, dst_t * dst,
        int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
        /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
        /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
        /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
        size_t s10, size_t s11, size_t s12,
        const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {

    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
                     item_ct1.get_local_id(2)) *
                    2;
    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
                    item_ct1.get_local_id(1);
    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
                     item_ct1.get_local_id(0)) /
                    ne12;
    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
                     item_ct1.get_local_id(0)) %
                    ne12;

    if (i00 >= ne00) {
        return;
    }

    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];

    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;

    const int ib = i00/qk; // block index
    const int iqs = (i00%qk)/qr; // quant index
    const int iybs = i00 - i00%qk; // dst block start index
    const int y_offset = qr == 1 ? 1 : qk/2;

    // dequantize
    dfloat2 v;
    dequantize_kernel(src0_row, ib, iqs, v);

    dst_row[iybs + iqs + 0] = v.x();
    dst_row[iybs + iqs + y_offset] = v.y();
}

template<int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_recorder, typename dst_t>
static void k_get_rows_reorder(
        const void * src0, const void *src0_dq, const int32_t * src1, dst_t * dst,
        int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
        /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
        /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
        /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
        size_t s10, size_t s11, size_t s12,
        const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {

    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
                     item_ct1.get_local_id(2)) *
                    2;
    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
                    item_ct1.get_local_id(1);
    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
                     item_ct1.get_local_id(0)) /
                    ne12;
    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
                     item_ct1.get_local_id(0)) %
                    ne12;

    if (i00 >= ne00) {
        return;
    }
    auto ncols = ne00;
    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];

    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;

    const int src0_off = i01 * ncols + i00;
    const int ib = src0_off / QK4_0; // block index
    const int iqs = (i00%qk)/qr; // x quant index
    const int iybs = i00 - i00%qk; // dst block start index
    const int y_offset = qr == 1 ? 1 : qk/2;

    // dequantize
    dfloat2 v;
    dequantize_kernel_recorder((const void *)src0_dq, ib, (const void *)src0, src0_off/2, v);

    dst_row[iybs + iqs + 0] = v.x();
    dst_row[iybs + iqs + y_offset] = v.y();

    GGML_UNUSED(nb01);
    GGML_UNUSED(nb02);
    GGML_UNUSED(nb03);
}

template<typename src0_t, typename dst_t>
static void k_get_rows_float(
        const src0_t * src0, const int32_t * src1, dst_t * dst,
        int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
        /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
        /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
        /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
        size_t s10, size_t s11, size_t s12,
        const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {

    const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
                    item_ct1.get_local_id(2);
    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
                    item_ct1.get_local_id(1);
    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
                     item_ct1.get_local_id(0)) /
                    ne12;
    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
                     item_ct1.get_local_id(0)) %
                    ne12;

    if (i00 >= ne00) {
        return;
    }

    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];

    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);

    dst_row[i00] = src0_row[i00];
}

template <int qk, int qr, dequantize_kernel_t dq>
static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
                          ggml_tensor *dst, const void *src0_dd,
                          const int32_t *src1_dd, float *dst_dd,
                          queue_ptr stream) {

    GGML_TENSOR_BINARY_OP_LOCALS

    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);

    // strides in elements
    //const size_t s0 = nb0 / ggml_element_size(dst);
    const size_t s1 = nb1 / ggml_element_size(dst);
    const size_t s2 = nb2 / ggml_element_size(dst);
    const size_t s3 = nb3 / ggml_element_size(dst);

    const size_t s10 = nb10 / ggml_element_size(src1);
    const size_t s11 = nb11 / ggml_element_size(src1);
    const size_t s12 = nb12 / ggml_element_size(src1);
    //const size_t s13 = nb13 / ggml_element_size(src1);

    GGML_ASSERT(ne00 % 2 == 0);

    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
                         [=](sycl::nd_item<3> item_ct1) {
                             k_get_rows<qk, qr, dq>(
                                 src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
                         });

    GGML_UNUSED(dst);
    GGML_UNUSED(ctx);
}

template <int qk, int qr, dequantize_kernel_t_reorder dq_reorder>
static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
                                  ggml_tensor *dst, const void *src0_dd,
                                  const int32_t *src1_dd, float *dst_dd,
                                  queue_ptr stream) {

    GGML_TENSOR_BINARY_OP_LOCALS

    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);

    // strides in elements
    //const size_t s0 = nb0 / ggml_element_size(dst);
    const size_t s1 = nb1 / ggml_element_size(dst);
    const size_t s2 = nb2 / ggml_element_size(dst);
    const size_t s3 = nb3 / ggml_element_size(dst);

    const size_t s10 = nb10 / ggml_element_size(src1);
    const size_t s11 = nb11 / ggml_element_size(src1);
    const size_t s12 = nb12 / ggml_element_size(src1);
    //const size_t s13 = nb13 / ggml_element_size(src1);

    GGML_ASSERT(ne00 % 2 == 0);

    const uint8_t* src0_q = (const uint8_t*)src0_dd;
    const size_t ncols = ne00;
    const size_t nrows = ne01;
    const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2);
    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
                         [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
                             k_get_rows_reorder<qk, qr, dq_reorder>(
                                 src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2,
                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
                         });

    GGML_UNUSED(dst);
    GGML_UNUSED(ctx);
}


template <typename src0_t>
static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                const ggml_tensor *src1, ggml_tensor *dst,
                                const src0_t *src0_dd, const int32_t *src1_dd,
                                float *dst_dd, queue_ptr stream) {

    GGML_TENSOR_BINARY_OP_LOCALS

    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
    const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);

    // strides in elements
    //const size_t s0 = nb0 / ggml_element_size(dst);
    const size_t s1 = nb1 / ggml_element_size(dst);
    const size_t s2 = nb2 / ggml_element_size(dst);
    const size_t s3 = nb3 / ggml_element_size(dst);

    const size_t s10 = nb10 / ggml_element_size(src1);
    const size_t s11 = nb11 / ggml_element_size(src1);
    const size_t s12 = nb12 / ggml_element_size(src1);
    //const size_t s13 = nb13 / ggml_element_size(src1);

    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->parallel_for(
            sycl::nd_range<3>(block_nums * block_dims, block_dims),
            [=](sycl::nd_item<3> item_ct1) {
                k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
            });
    }

    GGML_UNUSED(dst);
    GGML_UNUSED(ctx);
}

void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                           const ggml_tensor *src1, ggml_tensor *dst,
                           const float *src0_d, const float *src1_d,
                           float *dst_d, const queue_ptr &stream) {

    GGML_ASSERT(src1->type == GGML_TYPE_I32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));

    const int32_t * src1_i32 = (const int32_t *) src1_d;

    switch (src0->type) {
        case GGML_TYPE_F16:
            get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
                                src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_F32:
            get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_Q4_0:
            if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) {
                get_rows_sycl_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            } else {
                get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            }
            break;
        case GGML_TYPE_Q4_1:
            get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_Q5_0:
            get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_Q5_1:
            get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_Q8_0:
            get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        default:
            // TODO: k-quants
            GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
            GGML_ABORT("fatal error");
            break;
    }
}

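Note: in get_rows_sycl and get_rows_sycl_reorder each work-item writes two destination values (hence the ne00 % 2 == 0 assert), so the x-dimension group count is ne00 divided by 2*SYCL_GET_ROWS_BLOCK_SIZE, rounded up. A worked example assuming SYCL_GET_ROWS_BLOCK_SIZE is 256 (an assumption — see presets.hpp for the real value):

    constexpr int block_size_example  = 256;
    constexpr int ne00_example        = 4096;  // example row width
    constexpr int block_num_x_example = (ne00_example + 2 * block_size_example - 1) / (2 * block_size_example);  // == 8
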
23
ggml/src/ggml-sycl/getrows.hpp
Normal file
@@ -0,0 +1,23 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//

//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//

#ifndef GGML_SYCL_GETROWS_HPP
#define GGML_SYCL_GETROWS_HPP

#include "common.hpp"

void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                           const ggml_tensor *src1, ggml_tensor *dst,
                           const float *src0_d, const float *src1_d,
                           float *dst_d, const queue_ptr &stream);

#endif // GGML_SYCL_GETROWS_HPP
Some files were not shown because too many files have changed in this diff.