mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-23 16:37:33 +03:00
Compare commits
2 Commits
b4689
...
0cc4m/vulk
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
25840747e6 | ||
|
|
7037e94852 |
@@ -674,7 +674,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--no-context-shift"},
|
||||
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
|
||||
string_format("disables context shift on inifinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
|
||||
[](common_params & params) {
|
||||
params.ctx_shift = false;
|
||||
}
|
||||
|
||||
@@ -249,30 +249,16 @@ class chat_template {
|
||||
inputs.add_generation_prompt = false;
|
||||
full = apply(inputs);
|
||||
}
|
||||
auto eos_pos_last = full.rfind(eos_token_);
|
||||
if (eos_pos_last == prefix.size() - eos_token_.size() ||
|
||||
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
|
||||
full = full.substr(0, eos_pos_last);
|
||||
}
|
||||
size_t common_prefix_length = 0;
|
||||
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
|
||||
if (prefix[i] != full[i]) {
|
||||
break;
|
||||
|
||||
if (full.find(prefix) != 0) {
|
||||
if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
|
||||
prefix = prefix.substr(0, prefix.size() - eos_token_.size());
|
||||
}
|
||||
if (prefix[i] == '<') {
|
||||
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
|
||||
// but it removes thinking tags for past messages.
|
||||
// The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
|
||||
continue;
|
||||
}
|
||||
common_prefix_length = i + 1;
|
||||
}
|
||||
auto example = full.substr(common_prefix_length);
|
||||
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
|
||||
if (full.find(prefix) != 0) {
|
||||
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
|
||||
} else {
|
||||
tool_call_example_ = example;
|
||||
}
|
||||
tool_call_example_ = full.substr(prefix.size());
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
|
||||
@@ -377,7 +363,7 @@ class chat_template {
|
||||
if (polyfill_tools) {
|
||||
adjusted_messages = add_system(inputs.messages,
|
||||
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
|
||||
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
|
||||
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
|
||||
} else {
|
||||
adjusted_messages = inputs.messages;
|
||||
}
|
||||
|
||||
@@ -1385,13 +1385,6 @@ static std::string strip(const std::string & s) {
|
||||
return s.substr(start, end - start + 1);
|
||||
}
|
||||
|
||||
static std::string capitalize(const std::string & s) {
|
||||
if (s.empty()) return s;
|
||||
auto result = s;
|
||||
result[0] = std::toupper(result[0]);
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::string html_escape(const std::string & s) {
|
||||
std::string result;
|
||||
result.reserve(s.size());
|
||||
@@ -1469,9 +1462,6 @@ public:
|
||||
if (method->get_name() == "strip") {
|
||||
vargs.expectArgs("strip method", {0, 0}, {0, 0});
|
||||
return Value(strip(str));
|
||||
} else if (method->get_name() == "capitalize") {
|
||||
vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
|
||||
return Value(capitalize(str));
|
||||
} else if (method->get_name() == "endswith") {
|
||||
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
|
||||
auto suffix = vargs.args[0].get<std::string>();
|
||||
@@ -1802,7 +1792,7 @@ private:
|
||||
auto left = parseStringConcat();
|
||||
if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
|
||||
|
||||
static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)");
|
||||
static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)");
|
||||
static std::regex not_tok(R"(not\b)");
|
||||
std::string op_str;
|
||||
while (!(op_str = consumeToken(compare_tok)).empty()) {
|
||||
@@ -2181,7 +2171,7 @@ private:
|
||||
using TemplateTokenIterator = TemplateTokenVector::const_iterator;
|
||||
|
||||
std::vector<std::string> parseVarNames() {
|
||||
static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)");
|
||||
static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)");
|
||||
|
||||
std::vector<std::string> group;
|
||||
if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names");
|
||||
@@ -2204,13 +2194,13 @@ private:
|
||||
}
|
||||
|
||||
TemplateTokenVector tokenize() {
|
||||
static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
|
||||
static std::regex comment_tok(R"(\{#([-~]?)([\s\S\r\n]*?)([-~]?)#\})");
|
||||
static std::regex expr_open_regex(R"(\{\{([-~])?)");
|
||||
static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
|
||||
static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
|
||||
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
|
||||
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
|
||||
static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
|
||||
static std::regex block_close_regex(R"(\s*([-~])?%\})");
|
||||
static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
|
||||
static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
|
||||
|
||||
TemplateTokenVector tokens;
|
||||
std::vector<std::string> group;
|
||||
@@ -2294,7 +2284,7 @@ private:
|
||||
auto post_space = parseBlockClose();
|
||||
tokens.push_back(std::make_unique<EndGenerationTemplateToken>(location, pre_space, post_space));
|
||||
} else if (keyword == "set") {
|
||||
static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))");
|
||||
static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))");
|
||||
|
||||
std::string ns;
|
||||
std::vector<std::string> var_names;
|
||||
@@ -2346,11 +2336,6 @@ private:
|
||||
throw std::runtime_error("Unexpected block: " + keyword);
|
||||
}
|
||||
} else if (std::regex_search(it, end, match, non_text_open_regex)) {
|
||||
if (!match.position()) {
|
||||
if (match[0] != "{#")
|
||||
throw std::runtime_error("Internal error: Expected a comment");
|
||||
throw std::runtime_error("Missing end of comment tag");
|
||||
}
|
||||
auto text_end = it + match.position();
|
||||
text = std::string(it, text_end);
|
||||
it = text_end;
|
||||
@@ -2415,7 +2400,7 @@ private:
|
||||
|
||||
auto text = text_token->text;
|
||||
if (post_space == SpaceHandling::Strip) {
|
||||
static std::regex trailing_space_regex(R"(\s+$)");
|
||||
static std::regex trailing_space_regex(R"((\s|\r|\n)+$)");
|
||||
text = std::regex_replace(text, trailing_space_regex, "");
|
||||
} else if (options.lstrip_blocks && it != end) {
|
||||
auto i = text.size();
|
||||
@@ -2425,7 +2410,7 @@ private:
|
||||
}
|
||||
}
|
||||
if (pre_space == SpaceHandling::Strip) {
|
||||
static std::regex leading_space_regex(R"(^\s+)");
|
||||
static std::regex leading_space_regex(R"(^(\s|\r|\n)+)");
|
||||
text = std::regex_replace(text, leading_space_regex, "");
|
||||
} else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
|
||||
if (text.length() > 0 && text[0] == '\n') {
|
||||
|
||||
@@ -9,7 +9,7 @@ struct common_speculative_params {
|
||||
int n_draft = 16; // max drafted tokens
|
||||
int n_reuse = 256;
|
||||
|
||||
float p_min = 0.9f; // min probability required to accept a token in the draft
|
||||
float p_min = 0.9f; // min probabiliy required to accept a token in the draft
|
||||
};
|
||||
|
||||
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
|
||||
|
||||
@@ -37,7 +37,7 @@ Once downloaded, place your model in the models folder in llama.cpp.
|
||||
|
||||
##### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it):
|
||||
```bash
|
||||
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
|
||||
./llama-cli -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
|
||||
```
|
||||
|
||||
### Windows:
|
||||
|
||||
Binary file not shown.
@@ -1600,10 +1600,6 @@ struct server_queue {
|
||||
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (!running) {
|
||||
QUE_DBG("%s", "terminate\n");
|
||||
return;
|
||||
}
|
||||
if (queue_tasks.empty()) {
|
||||
lock.unlock();
|
||||
break;
|
||||
@@ -1624,11 +1620,11 @@ struct server_queue {
|
||||
QUE_DBG("%s", "waiting for new tasks\n");
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (!running) {
|
||||
QUE_DBG("%s", "terminate\n");
|
||||
return;
|
||||
}
|
||||
if (queue_tasks.empty()) {
|
||||
if (!running) {
|
||||
QUE_DBG("%s", "terminate\n");
|
||||
return;
|
||||
}
|
||||
condition_tasks.wait(lock, [&]{
|
||||
return (!queue_tasks.empty() || !running);
|
||||
});
|
||||
@@ -2279,7 +2275,7 @@ struct server_context {
|
||||
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
|
||||
result.probs.push_back({
|
||||
cur_p->data[i].id,
|
||||
common_token_to_piece(ctx, cur_p->data[i].id, special),
|
||||
common_detokenize(ctx, {cur_p->data[i].id}, special),
|
||||
cur_p->data[i].p
|
||||
});
|
||||
}
|
||||
@@ -2301,7 +2297,7 @@ struct server_context {
|
||||
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
|
||||
result.probs.push_back({
|
||||
cur[i].id,
|
||||
common_token_to_piece(ctx, cur[i].id, special),
|
||||
common_detokenize(ctx, {cur[i].id}, special),
|
||||
cur[i].p
|
||||
});
|
||||
}
|
||||
@@ -4434,7 +4430,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// clean up function, to be called before exit
|
||||
auto clean_up = [&svr]() {
|
||||
SRV_INF("%s: cleaning up before exit...\n", __func__);
|
||||
svr->stop();
|
||||
llama_backend_free();
|
||||
};
|
||||
@@ -4451,6 +4446,10 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
if (!was_bound) {
|
||||
//LOG_ERROR("couldn't bind HTTP server socket", {
|
||||
// {"hostname", params.hostname},
|
||||
// {"port", params.port},
|
||||
//});
|
||||
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
|
||||
clean_up();
|
||||
return 1;
|
||||
@@ -4467,7 +4466,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
if (!ctx_server.load_model(params)) {
|
||||
clean_up();
|
||||
// t.join(); // FIXME: see below
|
||||
t.join();
|
||||
LOG_ERR("%s: exiting due to model loading error\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -4491,10 +4490,13 @@ int main(int argc, char ** argv) {
|
||||
});
|
||||
|
||||
shutdown_handler = [&](int) {
|
||||
// this will unblock start_loop()
|
||||
ctx_server.queue_tasks.terminate();
|
||||
};
|
||||
|
||||
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
|
||||
|
||||
ctx_server.queue_tasks.start_loop();
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
@@ -4509,13 +4511,8 @@ int main(int argc, char ** argv) {
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
|
||||
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
|
||||
|
||||
// this call blocks the main thread until queue_tasks.terminate() is called
|
||||
ctx_server.queue_tasks.start_loop();
|
||||
|
||||
clean_up();
|
||||
// t.join(); // FIXME: http thread may stuck if there is an on-going request. we don't need to care about this for now as the HTTP connection will already be closed at this point, but it's better to fix this
|
||||
t.join();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
7
examples/server/webui/package-lock.json
generated
7
examples/server/webui/package-lock.json
generated
@@ -13,7 +13,6 @@
|
||||
"@vscode/markdown-it-katex": "^1.1.1",
|
||||
"autoprefixer": "^10.4.20",
|
||||
"daisyui": "^4.12.14",
|
||||
"dexie": "^4.0.11",
|
||||
"highlight.js": "^11.10.0",
|
||||
"katex": "^0.16.15",
|
||||
"postcss": "^8.4.49",
|
||||
@@ -2339,12 +2338,6 @@
|
||||
"url": "https://github.com/sponsors/wooorm"
|
||||
}
|
||||
},
|
||||
"node_modules/dexie": {
|
||||
"version": "4.0.11",
|
||||
"resolved": "https://registry.npmjs.org/dexie/-/dexie-4.0.11.tgz",
|
||||
"integrity": "sha512-SOKO002EqlvBYYKQSew3iymBoN2EQ4BDw/3yprjh7kAfFzjBYkaMNa/pZvcA7HSWlcKSQb9XhPe3wKyQ0x4A8A==",
|
||||
"license": "Apache-2.0"
|
||||
},
|
||||
"node_modules/didyoumean": {
|
||||
"version": "1.2.2",
|
||||
"resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz",
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
"@vscode/markdown-it-katex": "^1.1.1",
|
||||
"autoprefixer": "^10.4.20",
|
||||
"daisyui": "^4.12.14",
|
||||
"dexie": "^4.0.11",
|
||||
"highlight.js": "^11.10.0",
|
||||
"katex": "^0.16.15",
|
||||
"postcss": "^8.4.49",
|
||||
|
||||
@@ -3,7 +3,6 @@ import { useAppContext } from '../utils/app.context';
|
||||
import { Message, PendingMessage } from '../utils/types';
|
||||
import { classNames } from '../utils/misc';
|
||||
import MarkdownDisplay, { CopyButton } from './MarkdownDisplay';
|
||||
import { ChevronLeftIcon, ChevronRightIcon } from '@heroicons/react/24/outline';
|
||||
|
||||
interface SplitMessage {
|
||||
content: PendingMessage['content'];
|
||||
@@ -13,24 +12,17 @@ interface SplitMessage {
|
||||
|
||||
export default function ChatMessage({
|
||||
msg,
|
||||
siblingLeafNodeIds,
|
||||
siblingCurrIdx,
|
||||
id,
|
||||
onRegenerateMessage,
|
||||
onEditMessage,
|
||||
onChangeSibling,
|
||||
scrollToBottom,
|
||||
isPending,
|
||||
}: {
|
||||
msg: Message | PendingMessage;
|
||||
siblingLeafNodeIds: Message['id'][];
|
||||
siblingCurrIdx: number;
|
||||
id?: string;
|
||||
onRegenerateMessage(msg: Message): void;
|
||||
onEditMessage(msg: Message, content: string): void;
|
||||
onChangeSibling(sibling: Message['id']): void;
|
||||
scrollToBottom: (requiresNearBottom: boolean) => void;
|
||||
isPending?: boolean;
|
||||
}) {
|
||||
const { viewingChat, config } = useAppContext();
|
||||
const { viewingConversation, replaceMessageAndGenerate, config } =
|
||||
useAppContext();
|
||||
const [editingContent, setEditingContent] = useState<string | null>(null);
|
||||
const timings = useMemo(
|
||||
() =>
|
||||
@@ -45,8 +37,6 @@ export default function ChatMessage({
|
||||
: null,
|
||||
[msg.timings]
|
||||
);
|
||||
const nextSibling = siblingLeafNodeIds[siblingCurrIdx + 1];
|
||||
const prevSibling = siblingLeafNodeIds[siblingCurrIdx - 1];
|
||||
|
||||
// for reasoning model, we split the message into content and thought
|
||||
// TODO: implement this as remark/rehype plugin in the future
|
||||
@@ -74,7 +64,13 @@ export default function ChatMessage({
|
||||
return { content: actualContent, thought, isThinking };
|
||||
}, [msg]);
|
||||
|
||||
if (!viewingChat) return null;
|
||||
if (!viewingConversation) return null;
|
||||
|
||||
const regenerate = async () => {
|
||||
replaceMessageAndGenerate(viewingConversation.id, msg.id, undefined, () =>
|
||||
scrollToBottom(true)
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="group" id={id}>
|
||||
@@ -109,12 +105,13 @@ export default function ChatMessage({
|
||||
</button>
|
||||
<button
|
||||
className="btn mt-2"
|
||||
onClick={() => {
|
||||
if (msg.content !== null) {
|
||||
setEditingContent(null);
|
||||
onEditMessage(msg as Message, editingContent);
|
||||
}
|
||||
}}
|
||||
onClick={() =>
|
||||
replaceMessageAndGenerate(
|
||||
viewingConversation.id,
|
||||
msg.id,
|
||||
editingContent
|
||||
)
|
||||
}
|
||||
>
|
||||
Submit
|
||||
</button>
|
||||
@@ -199,35 +196,10 @@ export default function ChatMessage({
|
||||
{msg.content !== null && (
|
||||
<div
|
||||
className={classNames({
|
||||
'flex items-center gap-2 mx-4 mt-2 mb-2': true,
|
||||
'flex-row-reverse': msg.role === 'user',
|
||||
'mx-4 mt-2 mb-2': true,
|
||||
'text-right': msg.role === 'user',
|
||||
})}
|
||||
>
|
||||
{siblingLeafNodeIds && siblingLeafNodeIds.length > 1 && (
|
||||
<div className="flex gap-1 items-center opacity-60 text-sm">
|
||||
<button
|
||||
className={classNames({
|
||||
'btn btn-sm btn-ghost p-1': true,
|
||||
'opacity-20': !prevSibling,
|
||||
})}
|
||||
onClick={() => prevSibling && onChangeSibling(prevSibling)}
|
||||
>
|
||||
<ChevronLeftIcon className="h-4 w-4" />
|
||||
</button>
|
||||
<span>
|
||||
{siblingCurrIdx + 1} / {siblingLeafNodeIds.length}
|
||||
</span>
|
||||
<button
|
||||
className={classNames({
|
||||
'btn btn-sm btn-ghost p-1': true,
|
||||
'opacity-20': !nextSibling,
|
||||
})}
|
||||
onClick={() => nextSibling && onChangeSibling(nextSibling)}
|
||||
>
|
||||
<ChevronRightIcon className="h-4 w-4" />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{/* user message */}
|
||||
{msg.role === 'user' && (
|
||||
<button
|
||||
@@ -244,11 +216,7 @@ export default function ChatMessage({
|
||||
{!isPending && (
|
||||
<button
|
||||
className="badge btn-mini show-on-hover mr-2"
|
||||
onClick={() => {
|
||||
if (msg.content !== null) {
|
||||
onRegenerateMessage(msg as Message);
|
||||
}
|
||||
}}
|
||||
onClick={regenerate}
|
||||
disabled={msg.content === null}
|
||||
>
|
||||
🔄 Regenerate
|
||||
|
||||
@@ -1,59 +1,28 @@
|
||||
import { useEffect, useMemo, useState } from 'react';
|
||||
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
|
||||
import ChatMessage from './ChatMessage';
|
||||
import { CanvasType, Message, PendingMessage } from '../utils/types';
|
||||
import { classNames, throttle } from '../utils/misc';
|
||||
import CanvasPyInterpreter from './CanvasPyInterpreter';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import StorageUtils from '../utils/storage';
|
||||
import { useNavigate } from 'react-router';
|
||||
import ChatMessage from './ChatMessage';
|
||||
import { CanvasType, PendingMessage } from '../utils/types';
|
||||
import { classNames } from '../utils/misc';
|
||||
import CanvasPyInterpreter from './CanvasPyInterpreter';
|
||||
|
||||
/**
|
||||
* A message display is a message node with additional information for rendering.
|
||||
* For example, siblings of the message node are stored as their last node (aka leaf node).
|
||||
*/
|
||||
export interface MessageDisplay {
|
||||
msg: Message | PendingMessage;
|
||||
siblingLeafNodeIds: Message['id'][];
|
||||
siblingCurrIdx: number;
|
||||
isPending?: boolean;
|
||||
}
|
||||
export default function ChatScreen() {
|
||||
const {
|
||||
viewingConversation,
|
||||
sendMessage,
|
||||
isGenerating,
|
||||
stopGenerating,
|
||||
pendingMessages,
|
||||
canvasData,
|
||||
} = useAppContext();
|
||||
const [inputMsg, setInputMsg] = useState('');
|
||||
const navigate = useNavigate();
|
||||
|
||||
function getListMessageDisplay(
|
||||
msgs: Readonly<Message[]>,
|
||||
leafNodeId: Message['id']
|
||||
): MessageDisplay[] {
|
||||
const currNodes = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true);
|
||||
const res: MessageDisplay[] = [];
|
||||
const nodeMap = new Map<Message['id'], Message>();
|
||||
for (const msg of msgs) {
|
||||
nodeMap.set(msg.id, msg);
|
||||
}
|
||||
// find leaf node from a message node
|
||||
const findLeafNode = (msgId: Message['id']): Message['id'] => {
|
||||
let currNode: Message | undefined = nodeMap.get(msgId);
|
||||
while (currNode) {
|
||||
if (currNode.children.length === 0) break;
|
||||
currNode = nodeMap.get(currNode.children.at(-1) ?? -1);
|
||||
}
|
||||
return currNode?.id ?? -1;
|
||||
};
|
||||
// traverse the current nodes
|
||||
for (const msg of currNodes) {
|
||||
const parentNode = nodeMap.get(msg.parent ?? -1);
|
||||
if (!parentNode) continue;
|
||||
const siblings = parentNode.children;
|
||||
if (msg.type !== 'root') {
|
||||
res.push({
|
||||
msg,
|
||||
siblingLeafNodeIds: siblings.map(findLeafNode),
|
||||
siblingCurrIdx: siblings.indexOf(msg.id),
|
||||
});
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
const currConvId = viewingConversation?.id ?? '';
|
||||
const pendingMsg: PendingMessage | undefined = pendingMessages[currConvId];
|
||||
|
||||
const scrollToBottom = throttle(
|
||||
(requiresNearBottom: boolean, delay: number = 80) => {
|
||||
const scrollToBottom = (requiresNearBottom: boolean) => {
|
||||
const mainScrollElem = document.getElementById('main-scroll');
|
||||
if (!mainScrollElem) return;
|
||||
const spaceToBottom =
|
||||
@@ -63,107 +32,36 @@ const scrollToBottom = throttle(
|
||||
if (!requiresNearBottom || spaceToBottom < 50) {
|
||||
setTimeout(
|
||||
() => mainScrollElem.scrollTo({ top: mainScrollElem.scrollHeight }),
|
||||
delay
|
||||
1
|
||||
);
|
||||
}
|
||||
},
|
||||
80
|
||||
);
|
||||
|
||||
export default function ChatScreen() {
|
||||
const {
|
||||
viewingChat,
|
||||
sendMessage,
|
||||
isGenerating,
|
||||
stopGenerating,
|
||||
pendingMessages,
|
||||
canvasData,
|
||||
replaceMessageAndGenerate,
|
||||
} = useAppContext();
|
||||
const [inputMsg, setInputMsg] = useState('');
|
||||
|
||||
// keep track of leaf node for rendering
|
||||
const [currNodeId, setCurrNodeId] = useState<number>(-1);
|
||||
const messages: MessageDisplay[] = useMemo(() => {
|
||||
if (!viewingChat) return [];
|
||||
else return getListMessageDisplay(viewingChat.messages, currNodeId);
|
||||
}, [currNodeId, viewingChat]);
|
||||
|
||||
const currConvId = viewingChat?.conv.id ?? null;
|
||||
const pendingMsg: PendingMessage | undefined =
|
||||
pendingMessages[currConvId ?? ''];
|
||||
|
||||
useEffect(() => {
|
||||
// reset to latest node when conversation changes
|
||||
setCurrNodeId(-1);
|
||||
// scroll to bottom when conversation changes
|
||||
scrollToBottom(false, 1);
|
||||
}, [currConvId]);
|
||||
|
||||
const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => {
|
||||
if (currLeafNodeId) {
|
||||
setCurrNodeId(currLeafNodeId);
|
||||
}
|
||||
scrollToBottom(true);
|
||||
};
|
||||
|
||||
// scroll to bottom when conversation changes
|
||||
useEffect(() => {
|
||||
scrollToBottom(false);
|
||||
}, [viewingConversation?.id]);
|
||||
|
||||
const sendNewMessage = async () => {
|
||||
if (inputMsg.trim().length === 0 || isGenerating(currConvId ?? '')) return;
|
||||
if (inputMsg.trim().length === 0 || isGenerating(currConvId)) return;
|
||||
const convId = viewingConversation?.id ?? StorageUtils.getNewConvId();
|
||||
const lastInpMsg = inputMsg;
|
||||
setInputMsg('');
|
||||
if (!viewingConversation) {
|
||||
// if user is creating a new conversation, redirect to the new conversation
|
||||
navigate(`/chat/${convId}`);
|
||||
}
|
||||
scrollToBottom(false);
|
||||
setCurrNodeId(-1);
|
||||
// get the last message node
|
||||
const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
|
||||
if (!(await sendMessage(currConvId, lastMsgNodeId, inputMsg, onChunk))) {
|
||||
// auto scroll as message is being generated
|
||||
const onChunk = () => scrollToBottom(true);
|
||||
if (!(await sendMessage(convId, inputMsg, onChunk))) {
|
||||
// restore the input message if failed
|
||||
setInputMsg(lastInpMsg);
|
||||
}
|
||||
};
|
||||
|
||||
const handleEditMessage = async (msg: Message, content: string) => {
|
||||
if (!viewingChat) return;
|
||||
setCurrNodeId(msg.id);
|
||||
scrollToBottom(false);
|
||||
await replaceMessageAndGenerate(
|
||||
viewingChat.conv.id,
|
||||
msg.parent,
|
||||
content,
|
||||
onChunk
|
||||
);
|
||||
setCurrNodeId(-1);
|
||||
scrollToBottom(false);
|
||||
};
|
||||
|
||||
const handleRegenerateMessage = async (msg: Message) => {
|
||||
if (!viewingChat) return;
|
||||
setCurrNodeId(msg.parent);
|
||||
scrollToBottom(false);
|
||||
await replaceMessageAndGenerate(
|
||||
viewingChat.conv.id,
|
||||
msg.parent,
|
||||
null,
|
||||
onChunk
|
||||
);
|
||||
setCurrNodeId(-1);
|
||||
scrollToBottom(false);
|
||||
};
|
||||
|
||||
const hasCanvas = !!canvasData;
|
||||
|
||||
// due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
|
||||
const pendingMsgDisplay: MessageDisplay[] =
|
||||
pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id
|
||||
? [
|
||||
{
|
||||
msg: pendingMsg,
|
||||
siblingLeafNodeIds: [],
|
||||
siblingCurrIdx: 0,
|
||||
isPending: true,
|
||||
},
|
||||
]
|
||||
: [];
|
||||
|
||||
return (
|
||||
<div
|
||||
className={classNames({
|
||||
@@ -183,19 +81,24 @@ export default function ChatScreen() {
|
||||
<div id="messages-list" className="grow">
|
||||
<div className="mt-auto flex justify-center">
|
||||
{/* placeholder to shift the message to the bottom */}
|
||||
{viewingChat ? '' : 'Send a message to start'}
|
||||
{viewingConversation ? '' : 'Send a message to start'}
|
||||
</div>
|
||||
{[...messages, ...pendingMsgDisplay].map((msg) => (
|
||||
{viewingConversation?.messages.map((msg) => (
|
||||
<ChatMessage
|
||||
key={msg.msg.id}
|
||||
msg={msg.msg}
|
||||
siblingLeafNodeIds={msg.siblingLeafNodeIds}
|
||||
siblingCurrIdx={msg.siblingCurrIdx}
|
||||
onRegenerateMessage={handleRegenerateMessage}
|
||||
onEditMessage={handleEditMessage}
|
||||
onChangeSibling={setCurrNodeId}
|
||||
key={msg.id}
|
||||
msg={msg}
|
||||
scrollToBottom={scrollToBottom}
|
||||
/>
|
||||
))}
|
||||
|
||||
{pendingMsg && (
|
||||
<ChatMessage
|
||||
msg={pendingMsg}
|
||||
scrollToBottom={scrollToBottom}
|
||||
isPending
|
||||
id="pending-msg"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* chat input */}
|
||||
@@ -215,10 +118,10 @@ export default function ChatScreen() {
|
||||
id="msg-input"
|
||||
dir="auto"
|
||||
></textarea>
|
||||
{isGenerating(currConvId ?? '') ? (
|
||||
{isGenerating(currConvId) ? (
|
||||
<button
|
||||
className="btn btn-neutral ml-2"
|
||||
onClick={() => stopGenerating(currConvId ?? '')}
|
||||
onClick={() => stopGenerating(currConvId)}
|
||||
>
|
||||
Stop
|
||||
</button>
|
||||
|
||||
@@ -25,12 +25,12 @@ export default function Header() {
|
||||
);
|
||||
}, [selectedTheme]);
|
||||
|
||||
const { isGenerating, viewingChat } = useAppContext();
|
||||
const isCurrConvGenerating = isGenerating(viewingChat?.conv.id ?? '');
|
||||
const { isGenerating, viewingConversation } = useAppContext();
|
||||
const isCurrConvGenerating = isGenerating(viewingConversation?.id ?? '');
|
||||
|
||||
const removeConversation = () => {
|
||||
if (isCurrConvGenerating || !viewingChat) return;
|
||||
const convId = viewingChat?.conv.id;
|
||||
if (isCurrConvGenerating || !viewingConversation) return;
|
||||
const convId = viewingConversation.id;
|
||||
if (window.confirm('Are you sure to delete this conversation?')) {
|
||||
StorageUtils.remove(convId);
|
||||
navigate('/');
|
||||
@@ -38,9 +38,9 @@ export default function Header() {
|
||||
};
|
||||
|
||||
const downloadConversation = () => {
|
||||
if (isCurrConvGenerating || !viewingChat) return;
|
||||
const convId = viewingChat?.conv.id;
|
||||
const conversationJson = JSON.stringify(viewingChat, null, 2);
|
||||
if (isCurrConvGenerating || !viewingConversation) return;
|
||||
const convId = viewingConversation.id;
|
||||
const conversationJson = JSON.stringify(viewingConversation, null, 2);
|
||||
const blob = new Blob([conversationJson], { type: 'application/json' });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
@@ -75,41 +75,38 @@ export default function Header() {
|
||||
|
||||
{/* action buttons (top right) */}
|
||||
<div className="flex items-center">
|
||||
{viewingChat && (
|
||||
<div className="dropdown dropdown-end">
|
||||
{/* "..." button */}
|
||||
<button
|
||||
tabIndex={0}
|
||||
role="button"
|
||||
className="btn m-1"
|
||||
disabled={isCurrConvGenerating}
|
||||
<div v-if="messages.length > 0" className="dropdown dropdown-end">
|
||||
{/* "..." button */}
|
||||
<button
|
||||
tabIndex={0}
|
||||
role="button"
|
||||
className="btn m-1"
|
||||
disabled={isCurrConvGenerating}
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="16"
|
||||
height="16"
|
||||
fill="currentColor"
|
||||
className="bi bi-three-dots-vertical"
|
||||
viewBox="0 0 16 16"
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="16"
|
||||
height="16"
|
||||
fill="currentColor"
|
||||
className="bi bi-three-dots-vertical"
|
||||
viewBox="0 0 16 16"
|
||||
>
|
||||
<path d="M9.5 13a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0" />
|
||||
</svg>
|
||||
</button>
|
||||
{/* dropdown menu */}
|
||||
<ul
|
||||
tabIndex={0}
|
||||
className="dropdown-content menu bg-base-100 rounded-box z-[1] w-52 p-2 shadow"
|
||||
>
|
||||
<li onClick={downloadConversation}>
|
||||
<a>Download</a>
|
||||
</li>
|
||||
<li className="text-error" onClick={removeConversation}>
|
||||
<a>Delete</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<path d="M9.5 13a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0" />
|
||||
</svg>
|
||||
</button>
|
||||
{/* dropdown menu */}
|
||||
<ul
|
||||
tabIndex={0}
|
||||
className="dropdown-content menu bg-base-100 rounded-box z-[1] w-52 p-2 shadow"
|
||||
>
|
||||
<li onClick={downloadConversation}>
|
||||
<a>Download</a>
|
||||
</li>
|
||||
<li className="text-error" onClick={removeConversation}>
|
||||
<a>Delete</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div className="tooltip tooltip-bottom" data-tip="Settings">
|
||||
<button className="btn" onClick={() => setShowSettings(true)}>
|
||||
{/* settings button */}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useEffect, useMemo, useState } from 'react';
|
||||
import { classNames } from '../utils/misc';
|
||||
import { Conversation } from '../utils/types';
|
||||
import StorageUtils from '../utils/storage';
|
||||
@@ -7,17 +7,16 @@ import { useNavigate, useParams } from 'react-router';
|
||||
export default function Sidebar() {
|
||||
const params = useParams();
|
||||
const navigate = useNavigate();
|
||||
const currConv = useMemo(
|
||||
() => StorageUtils.getOneConversation(params.convId ?? ''),
|
||||
[params.convId]
|
||||
);
|
||||
|
||||
const [conversations, setConversations] = useState<Conversation[]>([]);
|
||||
const [currConv, setCurrConv] = useState<Conversation | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
StorageUtils.getOneConversation(params.convId ?? '').then(setCurrConv);
|
||||
}, [params.convId]);
|
||||
|
||||
useEffect(() => {
|
||||
const handleConversationChange = async () => {
|
||||
setConversations(await StorageUtils.getAllConversations());
|
||||
const handleConversationChange = () => {
|
||||
setConversations(StorageUtils.getAllConversations());
|
||||
};
|
||||
StorageUtils.onConversationChanged(handleConversationChange);
|
||||
handleConversationChange();
|
||||
@@ -83,11 +82,11 @@ export default function Sidebar() {
|
||||
onClick={() => navigate(`/chat/${conv.id}`)}
|
||||
dir="auto"
|
||||
>
|
||||
<span className="truncate">{conv.name}</span>
|
||||
<span className="truncate">{conv.messages[0].content}</span>
|
||||
</div>
|
||||
))}
|
||||
<div className="text-center text-xs opacity-40 mt-auto mx-4">
|
||||
Conversations are saved to browser's IndexedDB
|
||||
Conversations are saved to browser's localStorage
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -5,7 +5,6 @@ import {
|
||||
Conversation,
|
||||
Message,
|
||||
PendingMessage,
|
||||
ViewingChat,
|
||||
} from './types';
|
||||
import StorageUtils from './storage';
|
||||
import {
|
||||
@@ -14,25 +13,24 @@ import {
|
||||
getSSEStreamAsync,
|
||||
} from './misc';
|
||||
import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config';
|
||||
import { matchPath, useLocation, useNavigate } from 'react-router';
|
||||
import { matchPath, useLocation } from 'react-router';
|
||||
|
||||
interface AppContextValue {
|
||||
// conversations and messages
|
||||
viewingChat: ViewingChat | null;
|
||||
viewingConversation: Conversation | null;
|
||||
pendingMessages: Record<Conversation['id'], PendingMessage>;
|
||||
isGenerating: (convId: string) => boolean;
|
||||
sendMessage: (
|
||||
convId: string | null,
|
||||
leafNodeId: Message['id'] | null,
|
||||
convId: string,
|
||||
content: string,
|
||||
onChunk: CallbackGeneratedChunk
|
||||
onChunk?: CallbackGeneratedChunk
|
||||
) => Promise<boolean>;
|
||||
stopGenerating: (convId: string) => void;
|
||||
replaceMessageAndGenerate: (
|
||||
convId: string,
|
||||
parentNodeId: Message['id'], // the parent node of the message to be replaced
|
||||
content: string | null,
|
||||
onChunk: CallbackGeneratedChunk
|
||||
origMsgId: Message['id'],
|
||||
content?: string,
|
||||
onChunk?: CallbackGeneratedChunk
|
||||
) => Promise<void>;
|
||||
|
||||
// canvas
|
||||
@@ -46,33 +44,23 @@ interface AppContextValue {
|
||||
setShowSettings: (show: boolean) => void;
|
||||
}
|
||||
|
||||
// this callback is used for scrolling to the bottom of the chat and switching to the last node
|
||||
export type CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => void;
|
||||
// for now, this callback is only used for scrolling to the bottom of the chat
|
||||
type CallbackGeneratedChunk = () => void;
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const AppContext = createContext<AppContextValue>({} as any);
|
||||
|
||||
const getViewingChat = async (convId: string): Promise<ViewingChat | null> => {
|
||||
const conv = await StorageUtils.getOneConversation(convId);
|
||||
if (!conv) return null;
|
||||
return {
|
||||
conv: conv,
|
||||
// all messages from all branches, not filtered by last node
|
||||
messages: await StorageUtils.getMessages(convId),
|
||||
};
|
||||
};
|
||||
|
||||
export const AppContextProvider = ({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactElement;
|
||||
}) => {
|
||||
const { pathname } = useLocation();
|
||||
const navigate = useNavigate();
|
||||
const params = matchPath('/chat/:convId', pathname);
|
||||
const convId = params?.params?.convId;
|
||||
|
||||
const [viewingChat, setViewingChat] = useState<ViewingChat | null>(null);
|
||||
const [viewingConversation, setViewingConversation] =
|
||||
useState<Conversation | null>(null);
|
||||
const [pendingMessages, setPendingMessages] = useState<
|
||||
Record<Conversation['id'], PendingMessage>
|
||||
>({});
|
||||
@@ -87,12 +75,12 @@ export const AppContextProvider = ({
|
||||
useEffect(() => {
|
||||
// also reset the canvas data
|
||||
setCanvasData(null);
|
||||
const handleConversationChange = async (changedConvId: string) => {
|
||||
const handleConversationChange = (changedConvId: string) => {
|
||||
if (changedConvId !== convId) return;
|
||||
setViewingChat(await getViewingChat(changedConvId));
|
||||
setViewingConversation(StorageUtils.getOneConversation(convId));
|
||||
};
|
||||
StorageUtils.onConversationChanged(handleConversationChange);
|
||||
getViewingChat(convId ?? '').then(setViewingChat);
|
||||
setViewingConversation(StorageUtils.getOneConversation(convId ?? ''));
|
||||
return () => {
|
||||
StorageUtils.offConversationChanged(handleConversationChange);
|
||||
};
|
||||
@@ -130,39 +118,23 @@ export const AppContextProvider = ({
|
||||
|
||||
const generateMessage = async (
|
||||
convId: string,
|
||||
leafNodeId: Message['id'],
|
||||
onChunk: CallbackGeneratedChunk
|
||||
onChunk?: CallbackGeneratedChunk
|
||||
) => {
|
||||
if (isGenerating(convId)) return;
|
||||
|
||||
const config = StorageUtils.getConfig();
|
||||
const currConversation = await StorageUtils.getOneConversation(convId);
|
||||
const currConversation = StorageUtils.getOneConversation(convId);
|
||||
if (!currConversation) {
|
||||
throw new Error('Current conversation is not found');
|
||||
}
|
||||
|
||||
const currMessages = StorageUtils.filterByLeafNodeId(
|
||||
await StorageUtils.getMessages(convId),
|
||||
leafNodeId,
|
||||
false
|
||||
);
|
||||
const abortController = new AbortController();
|
||||
setAbort(convId, abortController);
|
||||
|
||||
if (!currMessages) {
|
||||
throw new Error('Current messages are not found');
|
||||
}
|
||||
|
||||
const pendingId = Date.now() + 1;
|
||||
let pendingMsg: PendingMessage = {
|
||||
id: pendingId,
|
||||
convId,
|
||||
type: 'text',
|
||||
timestamp: pendingId,
|
||||
id: Date.now() + 1,
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
parent: leafNodeId,
|
||||
children: [],
|
||||
};
|
||||
setPending(convId, pendingMsg);
|
||||
|
||||
@@ -172,7 +144,7 @@ export const AppContextProvider = ({
|
||||
...(config.systemMessage.length === 0
|
||||
? []
|
||||
: [{ role: 'system', content: config.systemMessage } as APIMessage]),
|
||||
...normalizeMsgsForAPI(currMessages),
|
||||
...normalizeMsgsForAPI(currConversation?.messages ?? []),
|
||||
];
|
||||
if (config.excludeThoughtOnReq) {
|
||||
messages = filterThoughtFromMsgs(messages);
|
||||
@@ -233,7 +205,8 @@ export const AppContextProvider = ({
|
||||
const lastContent = pendingMsg.content || '';
|
||||
if (addedContent) {
|
||||
pendingMsg = {
|
||||
...pendingMsg,
|
||||
id: pendingMsg.id,
|
||||
role: 'assistant',
|
||||
content: lastContent + addedContent,
|
||||
};
|
||||
}
|
||||
@@ -248,7 +221,7 @@ export const AppContextProvider = ({
|
||||
};
|
||||
}
|
||||
setPending(convId, pendingMsg);
|
||||
onChunk(); // don't need to switch node for pending message
|
||||
onChunk?.();
|
||||
}
|
||||
} catch (err) {
|
||||
setPending(convId, null);
|
||||
@@ -263,53 +236,37 @@ export const AppContextProvider = ({
|
||||
}
|
||||
}
|
||||
|
||||
if (pendingMsg.content !== null) {
|
||||
await StorageUtils.appendMsg(pendingMsg as Message, leafNodeId);
|
||||
if (pendingMsg.content) {
|
||||
StorageUtils.appendMsg(currConversation.id, {
|
||||
id: pendingMsg.id,
|
||||
content: pendingMsg.content,
|
||||
role: pendingMsg.role,
|
||||
timings: pendingMsg.timings,
|
||||
});
|
||||
}
|
||||
setPending(convId, null);
|
||||
onChunk(pendingId); // trigger scroll to bottom and switch to the last node
|
||||
onChunk?.(); // trigger scroll to bottom
|
||||
};
|
||||
|
||||
const sendMessage = async (
|
||||
convId: string | null,
|
||||
leafNodeId: Message['id'] | null,
|
||||
convId: string,
|
||||
content: string,
|
||||
onChunk: CallbackGeneratedChunk
|
||||
onChunk?: CallbackGeneratedChunk
|
||||
): Promise<boolean> => {
|
||||
if (isGenerating(convId ?? '') || content.trim().length === 0) return false;
|
||||
if (isGenerating(convId) || content.trim().length === 0) return false;
|
||||
|
||||
if (convId === null || convId.length === 0 || leafNodeId === null) {
|
||||
const conv = await StorageUtils.createConversation(
|
||||
content.substring(0, 256)
|
||||
);
|
||||
convId = conv.id;
|
||||
leafNodeId = conv.currNode;
|
||||
// if user is creating a new conversation, redirect to the new conversation
|
||||
navigate(`/chat/${convId}`);
|
||||
}
|
||||
|
||||
const now = Date.now();
|
||||
const currMsgId = now;
|
||||
StorageUtils.appendMsg(
|
||||
{
|
||||
id: currMsgId,
|
||||
timestamp: now,
|
||||
type: 'text',
|
||||
convId,
|
||||
role: 'user',
|
||||
content,
|
||||
parent: leafNodeId,
|
||||
children: [],
|
||||
},
|
||||
leafNodeId
|
||||
);
|
||||
onChunk(currMsgId);
|
||||
StorageUtils.appendMsg(convId, {
|
||||
id: Date.now(),
|
||||
role: 'user',
|
||||
content,
|
||||
});
|
||||
|
||||
try {
|
||||
await generateMessage(convId, currMsgId, onChunk);
|
||||
await generateMessage(convId, onChunk);
|
||||
return true;
|
||||
} catch (_) {
|
||||
// TODO: rollback
|
||||
// rollback
|
||||
StorageUtils.popMsg(convId);
|
||||
}
|
||||
return false;
|
||||
};
|
||||
@@ -322,33 +279,22 @@ export const AppContextProvider = ({
|
||||
// if content is undefined, we remove last assistant message
|
||||
const replaceMessageAndGenerate = async (
|
||||
convId: string,
|
||||
parentNodeId: Message['id'], // the parent node of the message to be replaced
|
||||
content: string | null,
|
||||
onChunk: CallbackGeneratedChunk
|
||||
origMsgId: Message['id'],
|
||||
content?: string,
|
||||
onChunk?: CallbackGeneratedChunk
|
||||
) => {
|
||||
if (isGenerating(convId)) return;
|
||||
|
||||
if (content !== null) {
|
||||
const now = Date.now();
|
||||
const currMsgId = now;
|
||||
StorageUtils.appendMsg(
|
||||
{
|
||||
id: currMsgId,
|
||||
timestamp: now,
|
||||
type: 'text',
|
||||
convId,
|
||||
role: 'user',
|
||||
content,
|
||||
parent: parentNodeId,
|
||||
children: [],
|
||||
},
|
||||
parentNodeId
|
||||
);
|
||||
parentNodeId = currMsgId;
|
||||
StorageUtils.filterAndKeepMsgs(convId, (msg) => msg.id < origMsgId);
|
||||
if (content) {
|
||||
StorageUtils.appendMsg(convId, {
|
||||
id: Date.now(),
|
||||
role: 'user',
|
||||
content,
|
||||
});
|
||||
}
|
||||
onChunk(parentNodeId);
|
||||
|
||||
await generateMessage(convId, parentNodeId, onChunk);
|
||||
await generateMessage(convId, onChunk);
|
||||
};
|
||||
|
||||
const saveConfig = (config: typeof CONFIG_DEFAULT) => {
|
||||
@@ -360,7 +306,7 @@ export const AppContextProvider = ({
|
||||
<AppContext.Provider
|
||||
value={{
|
||||
isGenerating,
|
||||
viewingChat,
|
||||
viewingConversation,
|
||||
pendingMessages,
|
||||
sendMessage,
|
||||
stopGenerating,
|
||||
|
||||
@@ -4,6 +4,7 @@ import { APIMessage, Message } from './types';
|
||||
|
||||
// ponyfill for missing ReadableStream asyncIterator on Safari
|
||||
import { asyncIterator } from '@sec-ant/readable-stream/ponyfill/asyncIterator';
|
||||
import { isDev } from '../Config';
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
export const isString = (x: any) => !!x.toLowerCase;
|
||||
@@ -22,7 +23,7 @@ export async function* getSSEStreamAsync(fetchResponse: Response) {
|
||||
.pipeThrough(new TextLineStream());
|
||||
// @ts-expect-error asyncIterator complains about type, but it should work
|
||||
for await (const line of asyncIterator(lines)) {
|
||||
//if (isDev) console.log({ line });
|
||||
if (isDev) console.log({ line });
|
||||
if (line.startsWith('data:') && !line.endsWith('[DONE]')) {
|
||||
const data = JSON.parse(line.slice(5));
|
||||
yield data;
|
||||
@@ -54,7 +55,7 @@ export const copyStr = (textToCopy: string) => {
|
||||
/**
|
||||
* filter out redundant fields upon sending to API
|
||||
*/
|
||||
export function normalizeMsgsForAPI(messages: Readonly<Message[]>) {
|
||||
export function normalizeMsgsForAPI(messages: Message[]) {
|
||||
return messages.map((msg) => {
|
||||
return {
|
||||
role: msg.role,
|
||||
@@ -87,23 +88,3 @@ export function classNames(classes: Record<string, boolean>): string {
|
||||
|
||||
export const delay = (ms: number) =>
|
||||
new Promise((resolve) => setTimeout(resolve, ms));
|
||||
|
||||
export const throttle = <T extends unknown[]>(
|
||||
callback: (...args: T) => void,
|
||||
delay: number
|
||||
) => {
|
||||
let isWaiting = false;
|
||||
|
||||
return (...args: T) => {
|
||||
if (isWaiting) {
|
||||
return;
|
||||
}
|
||||
|
||||
callback(...args);
|
||||
isWaiting = true;
|
||||
|
||||
setTimeout(() => {
|
||||
isWaiting = false;
|
||||
}, delay);
|
||||
};
|
||||
};
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
// format: { [convId]: { id: string, lastModified: number, messages: [...] } }
|
||||
|
||||
import { CONFIG_DEFAULT } from '../Config';
|
||||
import { Conversation, Message, TimingReport } from './types';
|
||||
import Dexie, { Table } from 'dexie';
|
||||
import { Conversation, Message } from './types';
|
||||
|
||||
const event = new EventTarget();
|
||||
|
||||
@@ -18,154 +17,85 @@ const dispatchConversationChange = (convId: string) => {
|
||||
);
|
||||
};
|
||||
|
||||
const db = new Dexie('LlamacppWebui') as Dexie & {
|
||||
conversations: Table<Conversation>;
|
||||
messages: Table<Message>;
|
||||
};
|
||||
|
||||
// https://dexie.org/docs/Version/Version.stores()
|
||||
db.version(1).stores({
|
||||
// Unlike SQL, you don’t need to specify all properties but only the one you wish to index.
|
||||
conversations: '&id, lastModified',
|
||||
messages: '&id, convId, [convId+id], timestamp',
|
||||
});
|
||||
|
||||
// convId is a string prefixed with 'conv-'
|
||||
const StorageUtils = {
|
||||
/**
|
||||
* manage conversations
|
||||
*/
|
||||
async getAllConversations(): Promise<Conversation[]> {
|
||||
await migrationLStoIDB().catch(console.error); // noop if already migrated
|
||||
return (await db.conversations.toArray()).sort(
|
||||
(a, b) => b.lastModified - a.lastModified
|
||||
);
|
||||
getAllConversations(): Conversation[] {
|
||||
const res = [];
|
||||
for (const key in localStorage) {
|
||||
if (key.startsWith('conv-')) {
|
||||
res.push(JSON.parse(localStorage.getItem(key) ?? '{}'));
|
||||
}
|
||||
}
|
||||
res.sort((a, b) => b.lastModified - a.lastModified);
|
||||
return res;
|
||||
},
|
||||
/**
|
||||
* can return null if convId does not exist
|
||||
*/
|
||||
async getOneConversation(convId: string): Promise<Conversation | null> {
|
||||
return (await db.conversations.where('id').equals(convId).first()) ?? null;
|
||||
getOneConversation(convId: string): Conversation | null {
|
||||
return JSON.parse(localStorage.getItem(convId) || 'null');
|
||||
},
|
||||
/**
|
||||
* get all message nodes in a conversation
|
||||
* if convId does not exist, create one
|
||||
*/
|
||||
async getMessages(convId: string): Promise<Message[]> {
|
||||
return await db.messages.where({ convId }).toArray();
|
||||
},
|
||||
/**
|
||||
* use in conjunction with getMessages to filter messages by leafNodeId
|
||||
* includeRoot: whether to include the root node in the result
|
||||
* if node with leafNodeId does not exist, return the path with the latest timestamp
|
||||
*/
|
||||
filterByLeafNodeId(
|
||||
msgs: Readonly<Message[]>,
|
||||
leafNodeId: Message['id'],
|
||||
includeRoot: boolean
|
||||
): Readonly<Message[]> {
|
||||
const res: Message[] = [];
|
||||
const nodeMap = new Map<Message['id'], Message>();
|
||||
for (const msg of msgs) {
|
||||
nodeMap.set(msg.id, msg);
|
||||
}
|
||||
let startNode: Message | undefined = nodeMap.get(leafNodeId);
|
||||
if (!startNode) {
|
||||
// if not found, we return the path with the latest timestamp
|
||||
let latestTime = -1;
|
||||
for (const msg of msgs) {
|
||||
if (msg.timestamp > latestTime) {
|
||||
startNode = msg;
|
||||
latestTime = msg.timestamp;
|
||||
}
|
||||
}
|
||||
}
|
||||
// traverse the path from leafNodeId to root
|
||||
// startNode can never be undefined here
|
||||
let currNode: Message | undefined = startNode;
|
||||
while (currNode) {
|
||||
if (currNode.type !== 'root' || (currNode.type === 'root' && includeRoot))
|
||||
res.push(currNode);
|
||||
currNode = nodeMap.get(currNode.parent ?? -1);
|
||||
}
|
||||
res.sort((a, b) => a.timestamp - b.timestamp);
|
||||
return res;
|
||||
},
|
||||
/**
|
||||
* create a new conversation with a default root node
|
||||
*/
|
||||
async createConversation(name: string): Promise<Conversation> {
|
||||
const now = Date.now();
|
||||
const msgId = now;
|
||||
const conv: Conversation = {
|
||||
id: `conv-${now}`,
|
||||
lastModified: now,
|
||||
currNode: msgId,
|
||||
name,
|
||||
};
|
||||
await db.conversations.add(conv);
|
||||
// create a root node
|
||||
await db.messages.add({
|
||||
id: msgId,
|
||||
convId: conv.id,
|
||||
type: 'root',
|
||||
timestamp: now,
|
||||
role: 'system',
|
||||
content: '',
|
||||
parent: -1,
|
||||
children: [],
|
||||
});
|
||||
return conv;
|
||||
},
|
||||
/**
|
||||
* if convId does not exist, throw an error
|
||||
*/
|
||||
async appendMsg(
|
||||
msg: Exclude<Message, 'parent' | 'children'>,
|
||||
parentNodeId: Message['id']
|
||||
): Promise<void> {
|
||||
appendMsg(convId: string, msg: Message): void {
|
||||
if (msg.content === null) return;
|
||||
const { convId } = msg;
|
||||
await db.transaction('rw', db.conversations, db.messages, async () => {
|
||||
const conv = await StorageUtils.getOneConversation(convId);
|
||||
const parentMsg = await db.messages
|
||||
.where({ convId, id: parentNodeId })
|
||||
.first();
|
||||
// update the currNode of conversation
|
||||
if (!conv) {
|
||||
throw new Error(`Conversation ${convId} does not exist`);
|
||||
}
|
||||
if (!parentMsg) {
|
||||
throw new Error(
|
||||
`Parent message ID ${parentNodeId} does not exist in conversation ${convId}`
|
||||
);
|
||||
}
|
||||
await db.conversations.update(convId, {
|
||||
lastModified: Date.now(),
|
||||
currNode: msg.id,
|
||||
});
|
||||
// update parent
|
||||
await db.messages.update(parentNodeId, {
|
||||
children: [...parentMsg.children, msg.id],
|
||||
});
|
||||
// create message
|
||||
await db.messages.add({
|
||||
...msg,
|
||||
parent: parentNodeId,
|
||||
children: [],
|
||||
});
|
||||
});
|
||||
const conv = StorageUtils.getOneConversation(convId) || {
|
||||
id: convId,
|
||||
lastModified: Date.now(),
|
||||
messages: [],
|
||||
};
|
||||
conv.messages.push(msg);
|
||||
conv.lastModified = Date.now();
|
||||
localStorage.setItem(convId, JSON.stringify(conv));
|
||||
dispatchConversationChange(convId);
|
||||
},
|
||||
/**
|
||||
* Get new conversation id
|
||||
*/
|
||||
getNewConvId(): string {
|
||||
return `conv-${Date.now()}`;
|
||||
},
|
||||
/**
|
||||
* remove conversation by id
|
||||
*/
|
||||
async remove(convId: string): Promise<void> {
|
||||
await db.transaction('rw', db.conversations, db.messages, async () => {
|
||||
await db.conversations.delete(convId);
|
||||
await db.messages.where({ convId }).delete();
|
||||
});
|
||||
remove(convId: string): void {
|
||||
localStorage.removeItem(convId);
|
||||
dispatchConversationChange(convId);
|
||||
},
|
||||
/**
|
||||
* remove all conversations
|
||||
*/
|
||||
filterAndKeepMsgs(
|
||||
convId: string,
|
||||
predicate: (msg: Message) => boolean
|
||||
): void {
|
||||
const conv = StorageUtils.getOneConversation(convId);
|
||||
if (!conv) return;
|
||||
conv.messages = conv.messages.filter(predicate);
|
||||
conv.lastModified = Date.now();
|
||||
localStorage.setItem(convId, JSON.stringify(conv));
|
||||
dispatchConversationChange(convId);
|
||||
},
|
||||
/**
|
||||
* remove last message from conversation
|
||||
*/
|
||||
popMsg(convId: string): Message | undefined {
|
||||
const conv = StorageUtils.getOneConversation(convId);
|
||||
if (!conv) return;
|
||||
const msg = conv.messages.pop();
|
||||
conv.lastModified = Date.now();
|
||||
if (conv.messages.length === 0) {
|
||||
StorageUtils.remove(convId);
|
||||
} else {
|
||||
localStorage.setItem(convId, JSON.stringify(conv));
|
||||
}
|
||||
dispatchConversationChange(convId);
|
||||
return msg;
|
||||
},
|
||||
|
||||
// event listeners
|
||||
onConversationChanged(callback: CallbackConversationChanged) {
|
||||
@@ -206,79 +136,3 @@ const StorageUtils = {
|
||||
};
|
||||
|
||||
export default StorageUtils;
|
||||
|
||||
// Migration from localStorage to IndexedDB
|
||||
|
||||
// these are old types, LS prefix stands for LocalStorage
|
||||
interface LSConversation {
|
||||
id: string; // format: `conv-{timestamp}`
|
||||
lastModified: number; // timestamp from Date.now()
|
||||
messages: LSMessage[];
|
||||
}
|
||||
interface LSMessage {
|
||||
id: number;
|
||||
role: 'user' | 'assistant' | 'system';
|
||||
content: string;
|
||||
timings?: TimingReport;
|
||||
}
|
||||
async function migrationLStoIDB() {
|
||||
if (localStorage.getItem('migratedToIDB')) return;
|
||||
const res: LSConversation[] = [];
|
||||
for (const key in localStorage) {
|
||||
if (key.startsWith('conv-')) {
|
||||
res.push(JSON.parse(localStorage.getItem(key) ?? '{}'));
|
||||
}
|
||||
}
|
||||
if (res.length === 0) return;
|
||||
await db.transaction('rw', db.conversations, db.messages, async () => {
|
||||
let migratedCount = 0;
|
||||
for (const conv of res) {
|
||||
const { id: convId, lastModified, messages } = conv;
|
||||
const firstMsg = messages[0];
|
||||
const lastMsg = messages.at(-1);
|
||||
if (messages.length < 2 || !firstMsg || !lastMsg) {
|
||||
console.log(
|
||||
`Skipping conversation ${convId} with ${messages.length} messages`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
const name = firstMsg.content ?? '(no messages)';
|
||||
await db.conversations.add({
|
||||
id: convId,
|
||||
lastModified,
|
||||
currNode: lastMsg.id,
|
||||
name,
|
||||
});
|
||||
const rootId = messages[0].id - 2;
|
||||
await db.messages.add({
|
||||
id: rootId,
|
||||
convId: convId,
|
||||
type: 'root',
|
||||
timestamp: rootId,
|
||||
role: 'system',
|
||||
content: '',
|
||||
parent: -1,
|
||||
children: [firstMsg.id],
|
||||
});
|
||||
for (let i = 0; i < messages.length; i++) {
|
||||
const msg = messages[i];
|
||||
await db.messages.add({
|
||||
...msg,
|
||||
type: 'text',
|
||||
convId: convId,
|
||||
timestamp: msg.id,
|
||||
parent: i === 0 ? rootId : messages[i - 1].id,
|
||||
children: i === messages.length - 1 ? [] : [messages[i + 1].id],
|
||||
});
|
||||
}
|
||||
migratedCount++;
|
||||
console.log(
|
||||
`Migrated conversation ${convId} with ${messages.length} messages`
|
||||
);
|
||||
}
|
||||
console.log(
|
||||
`Migrated ${migratedCount} conversations from localStorage to IndexedDB`
|
||||
);
|
||||
localStorage.setItem('migratedToIDB', '1');
|
||||
});
|
||||
}
|
||||
|
||||
@@ -5,46 +5,11 @@ export interface TimingReport {
|
||||
predicted_ms: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* What is conversation "branching"? It is a feature that allows the user to edit an old message in the history, while still keeping the conversation flow.
|
||||
* Inspired by ChatGPT / Claude / Hugging Chat where you edit a message, a new branch of the conversation is created, and the old message is still visible.
|
||||
*
|
||||
* We use the same node-based structure like other chat UIs, where each message has a parent and children. A "root" message is the first message in a conversation, which will not be displayed in the UI.
|
||||
*
|
||||
* root
|
||||
* ├── message 1
|
||||
* │ └── message 2
|
||||
* │ └── message 3
|
||||
* └── message 4
|
||||
* └── message 5
|
||||
*
|
||||
* In the above example, assuming that user wants to edit message 2, a new branch will be created:
|
||||
*
|
||||
* ├── message 2
|
||||
* │ └── message 3
|
||||
* └── message 6
|
||||
*
|
||||
* Message 2 and 6 are siblings, and message 6 is the new branch.
|
||||
*
|
||||
* We only need to know the last node (aka leaf) to get the current branch. In the above example, message 5 is the leaf of branch containing message 4 and 5.
|
||||
*
|
||||
* For the implementation:
|
||||
* - StorageUtils.getMessages() returns list of all nodes
|
||||
* - StorageUtils.filterByLeafNodeId() filters the list of nodes from a given leaf node
|
||||
*/
|
||||
|
||||
// Note: the term "message" and "node" are used interchangeably in this context
|
||||
export interface Message {
|
||||
id: number;
|
||||
convId: string;
|
||||
type: 'text' | 'root';
|
||||
timestamp: number; // timestamp from Date.now()
|
||||
role: 'user' | 'assistant' | 'system';
|
||||
content: string;
|
||||
timings?: TimingReport;
|
||||
// node based system for branching
|
||||
parent: Message['id'];
|
||||
children: Message['id'][];
|
||||
}
|
||||
|
||||
export type APIMessage = Pick<Message, 'role' | 'content'>;
|
||||
@@ -52,13 +17,7 @@ export type APIMessage = Pick<Message, 'role' | 'content'>;
|
||||
export interface Conversation {
|
||||
id: string; // format: `conv-{timestamp}`
|
||||
lastModified: number; // timestamp from Date.now()
|
||||
currNode: Message['id']; // the current message node being viewed
|
||||
name: string;
|
||||
}
|
||||
|
||||
export interface ViewingChat {
|
||||
conv: Readonly<Conversation>;
|
||||
messages: Readonly<Message[]>;
|
||||
messages: Message[];
|
||||
}
|
||||
|
||||
export type PendingMessage = Omit<Message, 'content'> & {
|
||||
|
||||
@@ -10,6 +10,8 @@ extern "C" {
|
||||
#define GGML_VK_NAME "Vulkan"
|
||||
#define GGML_VK_MAX_DEVICES 16
|
||||
|
||||
GGML_BACKEND_API void ggml_vk_instance_init(void);
|
||||
|
||||
// backend API
|
||||
GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);
|
||||
|
||||
|
||||
@@ -473,6 +473,7 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
|
||||
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
|
||||
GGML_TABLE_END()
|
||||
|
||||
//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
|
||||
GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
|
||||
0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
|
||||
0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
|
||||
@@ -507,6 +508,7 @@ GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
|
||||
0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
|
||||
0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
|
||||
GGML_TABLE_END()
|
||||
//#endif
|
||||
|
||||
|
||||
GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
|
||||
|
||||
@@ -284,14 +284,14 @@ struct ggml_backend_cpu_device_context {
|
||||
&hKey) == ERROR_SUCCESS) {
|
||||
DWORD cpu_brand_size = 0;
|
||||
if (RegQueryValueExA(hKey,
|
||||
"ProcessorNameString",
|
||||
TEXT("ProcessorNameString"),
|
||||
NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
&cpu_brand_size) == ERROR_SUCCESS) {
|
||||
description.resize(cpu_brand_size);
|
||||
if (RegQueryValueExA(hKey,
|
||||
"ProcessorNameString",
|
||||
TEXT("ProcessorNameString"),
|
||||
NULL,
|
||||
NULL,
|
||||
(LPBYTE)&description[0], // NOLINT
|
||||
|
||||
@@ -71,47 +71,6 @@
|
||||
#define GGML_CUDA_CC_QY1 210
|
||||
#define GGML_CUDA_CC_QY2 220
|
||||
|
||||
#ifdef __CUDA_ARCH_LIST__
|
||||
constexpr bool ggml_cuda_has_arch_impl(int) {
|
||||
return false;
|
||||
}
|
||||
|
||||
template<class ... Archs>
|
||||
constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
|
||||
return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
|
||||
}
|
||||
|
||||
constexpr bool ggml_cuda_has_arch(const int arch) {
|
||||
return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
|
||||
}
|
||||
|
||||
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
|
||||
if (cur == 0) {
|
||||
GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
|
||||
template<class ... Archs>
|
||||
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
|
||||
if (first <= arch && first > cur) {
|
||||
return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);
|
||||
} else {
|
||||
return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
|
||||
return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
|
||||
}
|
||||
#else
|
||||
static int ggml_cuda_highest_compiled_arch(const int arch) {
|
||||
return arch;
|
||||
}
|
||||
#endif // __CUDA_ARCH_LIST__
|
||||
|
||||
// ---------------------------------------------------------------------------------------------------------
|
||||
|
||||
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
@@ -203,32 +162,18 @@ typedef float2 dfloat2;
|
||||
#define FLASH_ATTN_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
|
||||
static bool fp16_available(const int cc) {
|
||||
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
|
||||
}
|
||||
|
||||
static bool fast_fp16_available(const int cc) {
|
||||
return fp16_available(cc) && cc != 610;
|
||||
}
|
||||
|
||||
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||
static bool fast_fp16_hardware_available(const int cc) {
|
||||
static constexpr bool fast_fp16_available(const int cc) {
|
||||
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
|
||||
}
|
||||
|
||||
// Any FP16 tensor core instructions are available for ggml code.
|
||||
static bool fp16_mma_available(const int cc) {
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
|
||||
}
|
||||
|
||||
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||
static bool fp16_mma_hardware_available(const int cc) {
|
||||
// Any FP16 tensor cores are available.
|
||||
static constexpr bool fp16_mma_available(const int cc) {
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
|
||||
}
|
||||
|
||||
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
|
||||
static bool new_mma_available(const int cc) {
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
|
||||
static constexpr bool new_mma_available(const int cc) {
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
|
||||
}
|
||||
|
||||
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
|
||||
|
||||
@@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||||
case GGML_TYPE_Q5_1:
|
||||
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
||||
case GGML_TYPE_Q8_0:
|
||||
if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
|
||||
if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) {
|
||||
return dequantize_block_q8_0_f16_cuda;
|
||||
}
|
||||
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
||||
|
||||
@@ -1867,14 +1867,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
|
||||
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
|
||||
}
|
||||
} else {
|
||||
const int cc = ggml_cuda_info().devices[ctx.device].cc;
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
|
||||
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
|
||||
}
|
||||
|
||||
// debug helpers
|
||||
@@ -3205,8 +3205,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
|
||||
return true;
|
||||
}
|
||||
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
|
||||
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
||||
return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||
}
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
|
||||
@@ -18,7 +18,7 @@ void ggml_cuda_op_mul_mat_q(
|
||||
const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
|
||||
|
||||
int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
||||
|
||||
// the main device has a larger memory buffer to hold the results from all GPUs
|
||||
// nrows_dst == nrows of the matrix that the kernel writes into
|
||||
@@ -27,8 +27,7 @@ void ggml_cuda_op_mul_mat_q(
|
||||
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
|
||||
// Also its fixup needs to allocate a temporary buffer in the memory pool.
|
||||
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
|
||||
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
|
||||
cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
|
||||
const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
|
||||
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
|
||||
|
||||
switch (src0->type) {
|
||||
@@ -137,7 +136,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
|
||||
if (cc < GGML_CUDA_CC_DP4A) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -146,7 +145,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
#endif //GGML_CUDA_FORCE_MMQ
|
||||
|
||||
if (cc < GGML_CUDA_CC_OFFSET_AMD) {
|
||||
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
|
||||
@@ -86,13 +86,12 @@ struct tile_x_sizes {
|
||||
int sc;
|
||||
};
|
||||
|
||||
static int get_mmq_x_max_host(const int cc) {
|
||||
static constexpr int get_mmq_x_max_host(const int cc) {
|
||||
return new_mma_available(cc) ? 128 :
|
||||
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
|
||||
#ifdef GGML_CUDA_FORCE_MMQ
|
||||
128 : 64;
|
||||
cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64;
|
||||
#else
|
||||
MMQ_DP4A_MAX_BATCH_SIZE : 64;
|
||||
cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
|
||||
#endif // GGML_CUDA_FORCE_MMQ
|
||||
}
|
||||
|
||||
@@ -120,9 +119,8 @@ static constexpr __device__ int get_mmq_x_max_device() {
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static int get_mmq_y_host(const int cc) {
|
||||
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
|
||||
(ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
|
||||
static constexpr int get_mmq_y_host(const int cc) {
|
||||
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64);
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_mmq_y_device() {
|
||||
@@ -2830,7 +2828,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
||||
const int mmq_x_max = get_mmq_x_max_host(cc);
|
||||
const int mmq_y = get_mmq_y_host(cc);
|
||||
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
|
||||
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
|
||||
const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
|
||||
|
||||
int mmq_x_best = 0;
|
||||
int nparts_best = INT_MAX;
|
||||
|
||||
@@ -149,6 +149,66 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf);
|
||||
|
||||
static constexpr uint32_t mul_mat_vec_max_cols = 8;
|
||||
|
||||
enum vk_device_architecture {
|
||||
OTHER,
|
||||
AMD_GCN,
|
||||
AMD_RDNA1,
|
||||
AMD_RDNA2,
|
||||
AMD_RDNA3,
|
||||
};
|
||||
|
||||
static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
|
||||
vk::PhysicalDeviceProperties props = device.getProperties();
|
||||
|
||||
if (props.vendorID == VK_VENDOR_ID_AMD) {
|
||||
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
|
||||
|
||||
bool amd_shader_core_properties = false;
|
||||
bool integer_dot_product = false;
|
||||
bool subgroup_size_control = false;
|
||||
|
||||
for (const auto& properties : ext_props) {
|
||||
if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) {
|
||||
amd_shader_core_properties = true;
|
||||
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
|
||||
integer_dot_product = true;
|
||||
} else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
|
||||
subgroup_size_control = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
|
||||
return vk_device_architecture::OTHER;
|
||||
}
|
||||
|
||||
vk::PhysicalDeviceProperties2 props2;
|
||||
vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
|
||||
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
|
||||
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
|
||||
|
||||
props2.pNext = &shader_core_props_amd;
|
||||
shader_core_props_amd.pNext = &integer_dot_props;
|
||||
integer_dot_props.pNext = &subgroup_size_control_props;
|
||||
|
||||
device.getProperties2(&props2);
|
||||
|
||||
if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) {
|
||||
return vk_device_architecture::AMD_GCN;
|
||||
}
|
||||
if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) {
|
||||
// RDNA
|
||||
if (shader_core_props_amd.wavefrontsPerSimd == 20) {
|
||||
return vk_device_architecture::AMD_RDNA1;
|
||||
}
|
||||
if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) {
|
||||
return vk_device_architecture::AMD_RDNA3;
|
||||
}
|
||||
return vk_device_architecture::AMD_RDNA2;
|
||||
}
|
||||
}
|
||||
return vk_device_architecture::OTHER;
|
||||
}
|
||||
|
||||
struct vk_device_struct {
|
||||
std::mutex mutex;
|
||||
|
||||
@@ -161,13 +221,13 @@ struct vk_device_struct {
|
||||
bool pipeline_robustness;
|
||||
vk::Device device;
|
||||
uint32_t vendor_id;
|
||||
vk_device_architecture architecture;
|
||||
vk_queue compute_queue;
|
||||
vk_queue transfer_queue;
|
||||
bool single_queue;
|
||||
uint32_t subgroup_size;
|
||||
uint32_t shader_core_count;
|
||||
bool uma;
|
||||
bool prefer_host_memory;
|
||||
bool float_controls_rte_fp16;
|
||||
|
||||
bool subgroup_size_control;
|
||||
@@ -1295,9 +1355,7 @@ static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk:
|
||||
static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
|
||||
vk_buffer buf;
|
||||
try {
|
||||
if (device->prefer_host_memory) {
|
||||
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||
} else if (device->uma) {
|
||||
if (device->uma) {
|
||||
// Fall back to host memory type
|
||||
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
|
||||
} else {
|
||||
@@ -1426,6 +1484,49 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
||||
return supported;
|
||||
}
|
||||
|
||||
struct GpuPipelineConfig {
|
||||
// List of all aliases for a given GPU.
|
||||
// For example, this can include names like "NAVI10", "RX 5700", etc.
|
||||
std::vector<std::string> device_names;
|
||||
|
||||
// Mapping of pipeline names to their specific subgroup sizes.
|
||||
// Example: {"soft_max_f32", 64}.
|
||||
std::unordered_map<std::string, uint32_t> pipelines;
|
||||
|
||||
// Default subgroup size for this GPU.
|
||||
// Defaults to 0 if not explicitly provided.
|
||||
uint32_t default_subgroup_size = 0;
|
||||
};
|
||||
|
||||
// Define configurations for different GPUs.
|
||||
static std::vector<GpuPipelineConfig> gpu_pipeline_configs = {
|
||||
{
|
||||
{"NAVI10", "NAVI14", "RX 5700", "RX 5600", "RX 5500"},
|
||||
{
|
||||
{"soft_max_f32", 64}, {"soft_max_f32_wg512", 64},
|
||||
{"soft_max_f32_f16", 64}, {"soft_max_f32_f16_wg512", 64},
|
||||
{"im2col_f32", 64}, {"im2col_f32_f16", 64},
|
||||
},
|
||||
32
|
||||
},
|
||||
};
|
||||
|
||||
static uint32_t get_subgroup_size(const std::string &pipeline_name, const std::string &device_name) {
|
||||
for (const auto &config : gpu_pipeline_configs) {
|
||||
for (const auto &alias : config.device_names) {
|
||||
if (device_name.find(alias) != std::string::npos) {
|
||||
auto pipIt = config.pipelines.find(pipeline_name);
|
||||
if (pipIt != config.pipelines.end() && pipIt->second != 0) {
|
||||
return pipIt->second;
|
||||
}
|
||||
return config.default_subgroup_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
// If no matching configuration is found, return 0.
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void ggml_vk_load_shaders(vk_device& device) {
|
||||
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
|
||||
|
||||
@@ -1546,11 +1647,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
||||
}
|
||||
|
||||
vk::PhysicalDeviceProperties2 props2;
|
||||
device->physical_device.getProperties2(&props2);
|
||||
std::string device_name = props2.properties.deviceName.data();
|
||||
|
||||
std::vector<std::future<void>> compiles;
|
||||
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
|
||||
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
|
||||
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
|
||||
|
||||
required_subgroup_size = get_subgroup_size(name, device_name);
|
||||
|
||||
if (!pipeline) {
|
||||
pipeline = std::make_shared<vk_pipeline_struct>();
|
||||
pipeline->name = name;
|
||||
@@ -2173,7 +2280,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
device->need_compiles = false;
|
||||
}
|
||||
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
|
||||
|
||||
static vk_device ggml_vk_get_device(size_t idx) {
|
||||
VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
|
||||
@@ -2202,8 +2309,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
device->physical_device = physical_devices[dev_num];
|
||||
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
|
||||
|
||||
const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
|
||||
device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
|
||||
device->architecture = get_device_architecture(device->physical_device);
|
||||
|
||||
bool fp16_storage = false;
|
||||
bool fp16_compute = false;
|
||||
@@ -2214,7 +2320,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
bool coopmat2_support = false;
|
||||
device->coopmat_support = false;
|
||||
|
||||
// Check if maintenance4 is supported
|
||||
for (const auto& properties : ext_props) {
|
||||
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
||||
maintenance4_support = true;
|
||||
@@ -2327,7 +2432,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
|
||||
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
|
||||
|
||||
if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) {
|
||||
if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) {
|
||||
device->coopmat_support = false;
|
||||
}
|
||||
|
||||
@@ -2705,7 +2810,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
subgroup_props.pNext = &driver_props;
|
||||
physical_device.getProperties2(&props2);
|
||||
|
||||
const size_t subgroup_size = subgroup_props.subgroupSize;
|
||||
uint32_t default_subgroup_size = get_subgroup_size("", props2.properties.deviceName.data());
|
||||
const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
|
||||
|
||||
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
||||
|
||||
bool fp16_storage = false;
|
||||
@@ -2731,7 +2838,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
}
|
||||
}
|
||||
|
||||
if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) {
|
||||
const vk_device_architecture device_architecture = get_device_architecture(physical_device);
|
||||
|
||||
if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
|
||||
coopmat_support = false;
|
||||
}
|
||||
|
||||
@@ -2793,12 +2902,14 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
||||
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
||||
|
||||
static void ggml_vk_instance_init() {
|
||||
void ggml_vk_instance_init() {
|
||||
if (vk_instance_initialized) {
|
||||
return;
|
||||
}
|
||||
VK_LOG_DEBUG("ggml_vk_instance_init()");
|
||||
|
||||
vk_instance_initialized = true;
|
||||
|
||||
uint32_t api_version = vk::enumerateInstanceVersion();
|
||||
|
||||
if (api_version < VK_API_VERSION_1_2) {
|
||||
@@ -2849,7 +2960,6 @@ static void ggml_vk_instance_init() {
|
||||
GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
|
||||
}
|
||||
vk_instance.instance = vk::createInstance(instance_create_info);
|
||||
vk_instance_initialized = true;
|
||||
|
||||
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
|
||||
|
||||
@@ -2874,7 +2984,7 @@ static void ggml_vk_instance_init() {
|
||||
// Make sure at least one device exists
|
||||
if (devices.empty()) {
|
||||
std::cerr << "ggml_vulkan: Error: No devices found." << std::endl;
|
||||
return;
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
// Default to using all dedicated GPUs
|
||||
@@ -8349,13 +8459,8 @@ ggml_backend_reg_t ggml_backend_vk_reg() {
|
||||
/* .iface = */ ggml_backend_vk_reg_i,
|
||||
/* .context = */ nullptr,
|
||||
};
|
||||
try {
|
||||
ggml_vk_instance_init();
|
||||
return ®
|
||||
} catch (const vk::SystemError& e) {
|
||||
VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return ®
|
||||
}
|
||||
|
||||
// Extension availability
|
||||
@@ -8394,7 +8499,7 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
|
||||
UNUSED(instance_extensions);
|
||||
}
|
||||
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
|
||||
switch (props.vendorID) {
|
||||
case VK_VENDOR_ID_INTEL:
|
||||
// Intel drivers don't support coopmat properly yet
|
||||
@@ -8402,10 +8507,7 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
|
||||
case VK_VENDOR_ID_AMD:
|
||||
if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
|
||||
// Workaround for AMD proprietary driver reporting support on all GPUs
|
||||
const std::string name = props.deviceName;
|
||||
return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs
|
||||
name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs
|
||||
name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs
|
||||
return arch == vk_device_architecture::AMD_RDNA3;
|
||||
}
|
||||
return true;
|
||||
default:
|
||||
|
||||
@@ -1379,7 +1379,7 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso
|
||||
(t0->nb[3] == t1->nb[3]);
|
||||
}
|
||||
|
||||
// check if t1 can be represented as a repetition of t0
|
||||
// check if t1 can be represented as a repeatition of t0
|
||||
bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user