Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2026-02-05 13:53:23 +02:00)

Compare commits (11 commits):

- 11d47057a5
- c421ac072d
- 4ff7fe1fb3
- 6b8447352d
- 674804a996
- e94a138d64
- e01c67affe
- 994cfb1acb
- 94008cc760
- dbd5f2f573
- f594bc80ba
@@ -93,6 +93,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
- [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat)
- [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a)
- [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM)

(instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))
@@ -173,6 +174,7 @@ Unless otherwise noted these projects are open-source with permissive licensing:
- [LARS - The LLM & Advanced Referencing Solution](https://github.com/abgulati/LARS) (AGPL)
- [LLMUnity](https://github.com/undreamai/LLMUnity) (MIT)
- [Llama Assistant](https://github.com/vietanhdev/llama-assistant) (GPL)
- [PocketPal AI - An iOS and Android App](https://github.com/a-ghorbani/pocketpal-ai) (MIT)

*(to have a project listed here, it should clearly state that it depends on `llama.cpp`)*
@@ -1097,7 +1097,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
add_opt(common_arg(
{"--attention"}, "{causal,non,causal}",
{"--attention"}, "{causal,non-causal}",
"attention type for embeddings, use model default if unspecified",
[](common_params & params, const std::string & value) {
/**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
@@ -1695,7 +1695,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_BENCH}));
add_opt(common_arg(
{"--embd-normalize"}, "N",
string_format("normalisation for embendings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize),
string_format("normalisation for embeddings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize),
[](common_params & params, int value) {
params.embd_normalize = value;
}
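To make the help string above concrete, here is a small self-contained sketch of what the normalization modes do to a raw embedding vector. It follows the option description, not the actual implementation in common.cpp, and the 32760 scale used for the int16 mode is an assumption:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Normalize an embedding according to the --embd-normalize convention:
// -1 = none, 0 = max absolute int16, 1 = taxicab (L1), 2 = euclidean (L2), >2 = p-norm.
static void embd_normalize_sketch(std::vector<float> & v, int p) {
    if (p < 0) {
        return; // -1: leave the embedding as-is
    }
    double norm = 0.0;
    if (p == 0) {
        // scale so that the largest |x| fits an int16-style range (32760 is an assumption)
        for (float x : v) norm = std::max(norm, (double) std::fabs(x));
        norm /= 32760.0;
    } else if (p == 1) {
        for (float x : v) norm += std::fabs(x);              // taxicab
    } else if (p == 2) {
        for (float x : v) norm += (double) x * x;            // euclidean
        norm = std::sqrt(norm);
    } else {
        for (float x : v) norm += std::pow(std::fabs(x), p); // general p-norm
        norm = std::pow(norm, 1.0 / p);
    }
    if (norm != 0.0) {
        for (float & x : v) x = (float) (x / norm);
    }
}

int main() {
    std::vector<float> e = {3.0f, -4.0f};
    embd_normalize_sketch(e, 2);            // euclidean: {0.6, -0.8}
    std::printf("%.2f %.2f\n", e[0], e[1]);
    return 0;
}
```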
@@ -1709,7 +1709,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
add_opt(common_arg(
{"--embd-separator"}, "STRING",
"separator of embendings (default \\n) for example \"<#sep#>\"",
"separator of embeddings (default \\n) for example \"<#sep#>\"",
[](common_params & params, const std::string & value) {
params.embd_sep = value;
}
@@ -1035,7 +1035,7 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
return GGML_TYPE_Q5_1;
}

throw std::runtime_error("Invalid cache type: " + s);
throw std::runtime_error("Unsupported cache type: " + s);
}

struct llama_context_params common_context_params_to_llama(const common_params & params) {
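The hunk above shows only the tail of kv_cache_type_from_str. As context, a reduced sketch of the kind of string-to-type mapping it performs; the handful of cache types listed here is illustrative, and the authoritative list lives in common.cpp:

```cpp
#include <stdexcept>
#include <string>
#include "ggml.h"

// Reduced illustration of mapping a user-facing K/V cache-type string to a ggml type.
static ggml_type kv_cache_type_from_str_sketch(const std::string & s) {
    if (s == "f16")  { return GGML_TYPE_F16;  }
    if (s == "q8_0") { return GGML_TYPE_Q8_0; }
    if (s == "q4_0") { return GGML_TYPE_Q4_0; }
    if (s == "q5_1") { return GGML_TYPE_Q5_1; }

    // unknown strings are rejected, matching the updated message above
    throw std::runtime_error("Unsupported cache type: " + s);
}
```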
@@ -1047,7 +1047,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.n_ubatch = params.n_ubatch;
cparams.n_threads = params.cpuparams.n_threads;
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
cparams.logits_all = params.logits_all;
cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
@@ -274,9 +274,9 @@ struct common_params {

// embedding
bool embedding = false; // get only sentence embedding
int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
std::string embd_sep = "\n"; // separator of embendings
std::string embd_sep = "\n"; // separator of embeddings
bool reranking = false; // enable reranking support on server

// server params
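The "json+" output format mentioned in the embd_out comment appends a cosine similarity matrix. For reference, a minimal sketch of the pairwise cosine similarity it refers to (illustrative only, not the code used by the embedding example):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Cosine similarity between two embeddings: dot(a, b) / (|a| * |b|).
// With embd_normalize = 2 (euclidean) both norms are already 1, so this
// reduces to a plain dot product.
static float cosine_similarity(const std::vector<float> & a, const std::vector<float> & b) {
    double dot = 0.0, na = 0.0, nb = 0.0;
    for (size_t i = 0; i < a.size() && i < b.size(); i++) {
        dot += (double) a[i] * b[i];
        na  += (double) a[i] * a[i];
        nb  += (double) b[i] * b[i];
    }
    if (na == 0.0 || nb == 0.0) {
        return 0.0f;
    }
    return (float) (dot / (std::sqrt(na) * std::sqrt(nb)));
}

int main() {
    std::vector<float> a = {0.6f, 0.8f};
    std::vector<float> b = {0.8f, 0.6f};
    std::printf("%.3f\n", cosine_similarity(a, b)); // 0.960
    return 0;
}
```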
@@ -2864,6 +2864,9 @@ class Rwkv6Model(Model):
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.chat_template = "rwkv-world"
# hack: Add '\n\n' as the EOT token to make it chat normally
special_vocab._set_special_token("eot", 261)
special_vocab.add_to_gguf(self.gguf_writer)

def set_gguf_parameters(self):
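The "rwkv-world" template registered here is rendered by the new branch added to llama_chat_apply_template_internal later in this compare. A short sketch of the resulting prompt format, with made-up message contents:

```cpp
#include <iostream>
#include <string>
#include <vector>

struct msg { std::string role, content; };

// Mirrors the rwkv-world branch added to llama_chat_apply_template_internal:
// user turns become "User: <content>\n\nAssistant:" and all other turns are
// emitted verbatim followed by "\n\n" (the token the conversion script maps to EOT).
static std::string render_rwkv_world(const std::vector<msg> & chat) {
    std::string ss;
    for (const auto & m : chat) {
        if (m.role == "user") {
            ss += "User: " + m.content + "\n\nAssistant:";
        } else {
            ss += m.content + "\n\n";
        }
    }
    return ss;
}

int main() {
    std::vector<msg> chat = {
        {"user", "What is llama.cpp?"},
        {"assistant", " A C/C++ LLM inference library."},
        {"user", "Thanks!"},
    };
    std::cout << render_rwkv_world(chat);
    return 0;
}
```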
@@ -348,6 +348,9 @@ if __name__ == '__main__':
if ".base_layer.weight" in name:
continue
logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
logger.error("Hint: if you are using TRL, make sure not to call setup_chat_format()")
sys.exit(1)

if base_name in tensor_map:
examples/llama.vim (new file, 697 lines)
@@ -0,0 +1,697 @@
" LLM-based text completion using llama.cpp
"
" requires:
"
" - neovim
" - curl
" - llama.cpp server instance
" - FIM-compatible model
"
" sample config:
"
" - Tab - accept the current suggestion
" - Shift+Tab - accept just the first line of the suggestion
" - Ctrl+F - toggle FIM completion manually
"
" make symlink or copy this file to ~/.config/nvim/autoload/llama.vim
"
" start the llama.cpp server with a FIM-compatible model. for example:
"
" $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa -dt 0.1 --ubatch-size 512 --batch-size 1024 --cache-reuse 256
"
" --batch-size [512, model max context]
"
" adjust the batch size to control how much of the provided local context will be used during the inference
" lower values will use smaller part of the context around the cursor, which will result in faster processing
"
" --ubatch-size [64, 2048]
"
" chunks the batch into smaller chunks for faster processing
" depends on the specific hardware. use llama-bench to profile and determine the best size
"
" --cache-reuse (g:llama_config.n_predict, 1024]
"
" this should be either 0 (disabled) or strictly larger than g:llama_config.n_predict
" using non-zero value enables context reuse on the server side which dramatically improves the performance at
" large contexts. a value of 256 should be good for all cases
"
" run this once to initialise llama.vim:
"
" :call llama#init()
"
" more info: https://github.com/ggerganov/llama.cpp/pull/9787
"

" colors (adjust to your liking)
highlight llama_hl_hint guifg=#ff772f
highlight llama_hl_info guifg=#77ff2f

" general parameters:
"
" endpoint: llama.cpp server endpoint
" n_prefix: number of lines before the cursor location to include in the local prefix
" n_suffix: number of lines after the cursor location to include in the local suffix
" n_predict: max number of tokens to predict
" t_max_prompt_ms: max allotted time for the prompt processing (TODO: not yet supported)
" t_max_predict_ms: max allotted time for the prediction
" show_info: show extra info about the inference (0 - disabled, 1 - statusline, 2 - inline)
" auto_fim: trigger FIM completion automatically on cursor movement
" max_line_suffix: do not auto-trigger FIM completion if there are more than this number of characters to the right of the cursor
"
" ring buffer of chunks, accumulated with time upon:
"
" - completion request
" - yank
" - entering a buffer
" - leaving a buffer
" - writing a file
"
" parameters for the ring-buffer with extra context:
"
" ring_n_chunks: max number of chunks to pass as extra context to the server (0 to disable)
" ring_chunk_size: max size of the chunks (in number of lines)
" note: adjust these numbers so that you don't overrun your context
" at ring_n_chunks = 64 and ring_chunk_size = 64 you need ~32k context
" ring_scope: the range around the cursor position (in number of lines) for gathering chunks after FIM
" ring_update_ms: how often to process queued chunks in normal mode
"
let s:default_config = {
\ 'endpoint': 'http://127.0.0.1:8012/infill',
\ 'n_prefix': 256,
\ 'n_suffix': 64,
\ 'n_predict': 128,
\ 't_max_prompt_ms': 500,
\ 't_max_predict_ms': 1000,
\ 'show_info': 2,
\ 'auto_fim': v:true,
\ 'max_line_suffix': 8,
\ 'ring_n_chunks': 64,
\ 'ring_chunk_size': 64,
\ 'ring_scope': 1024,
\ 'ring_update_ms': 1000,
\ }

let g:llama_config = get(g:, 'llama_config', s:default_config)
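The comments and defaults above describe what the plugin ultimately sends to the server. As a reference, here is a hedged sketch of the /infill request body that llama#fim assembles further down in this file, written in C++ with nlohmann/json purely for illustration (the plugin itself uses Vimscript's json_encode and curl); all sample values are made up:

```cpp
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    using nlohmann::json;

    // extra context chunks gathered from the ring buffer
    json extra = json::array();
    extra.push_back({{"text", "int add(int a, int b) { return a + b; }\n"},
                     {"time", 0},                 // the plugin passes reltime(); 0 is a placeholder
                     {"filename", "math.cpp"}});

    json request = {
        {"input_prefix", "int main() {\n"},       // lines before the cursor
        {"input_suffix", "\n    return 0;\n}\n"}, // lines after the cursor
        {"input_extra", extra},
        {"prompt", "    int x = "},               // text left of the cursor on the current line
        {"n_predict", 128},
        {"top_k", 40},
        {"top_p", 0.99},
        {"stream", false},
        {"samplers", {"top_k", "top_p", "infill"}},
        {"cache_prompt", true},
        {"t_max_prompt_ms", 500},
        {"t_max_predict_ms", 1000}
    };

    // llama.vim pipes this through curl to g:llama_config.endpoint,
    // http://127.0.0.1:8012/infill by default
    std::cout << request.dump(2) << std::endl;
    return 0;
}
```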
function! s:rand(i0, i1) abort
|
||||
return a:i0 + rand() % (a:i1 - a:i0 + 1)
|
||||
endfunction
|
||||
|
||||
function! llama#init()
|
||||
if !executable('curl')
|
||||
echohl WarningMsg
|
||||
echo 'llama.vim requires the "curl" command to be available'
|
||||
echohl None
|
||||
return
|
||||
endif
|
||||
|
||||
let s:pos_x = 0 " cursor position upon start of completion
|
||||
let s:pos_y = 0
|
||||
|
||||
let s:line_cur = ''
|
||||
|
||||
let s:line_cur_prefix = ''
|
||||
let s:line_cur_suffix = ''
|
||||
|
||||
let s:ring_chunks = [] " current set of chunks used as extra context
|
||||
let s:ring_queued = [] " chunks that are queued to be sent for processing
|
||||
let s:ring_n_evict = 0
|
||||
|
||||
let s:hint_shown = v:false
|
||||
let s:pos_y_pick = -9999 " last y where we picked a chunk
|
||||
let s:pos_dx = 0
|
||||
let s:content = []
|
||||
let s:can_accept = v:false
|
||||
|
||||
let s:timer_fim = -1
|
||||
let s:t_fim_start = reltime() " used to measure total FIM time
|
||||
let s:t_last_move = reltime() " last time the cursor moved
|
||||
|
||||
let s:current_job = v:null
|
||||
|
||||
augroup llama
|
||||
autocmd!
|
||||
autocmd InsertEnter * inoremap <expr> <silent> <C-F> llama#fim_inline(v:false)
|
||||
autocmd InsertLeavePre * call llama#fim_cancel()
|
||||
|
||||
autocmd CursorMoved * call s:on_move()
|
||||
autocmd CursorMovedI * call s:on_move()
|
||||
autocmd CompleteChanged * call llama#fim_cancel()
|
||||
|
||||
if g:llama_config.auto_fim
|
||||
autocmd CursorMovedI * call llama#fim(v:true)
|
||||
endif
|
||||
|
||||
" gather chunks upon yanking
|
||||
autocmd TextYankPost * if v:event.operator ==# 'y' | call s:pick_chunk(v:event.regcontents, v:false, v:true) | endif
|
||||
|
||||
" gather chunks upon entering/leaving a buffer
|
||||
autocmd BufEnter * call timer_start(100, {-> s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)})
|
||||
autocmd BufLeave * call s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)
|
||||
|
||||
" gather chunk upon saving the file
|
||||
autocmd BufWritePost * call s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)
|
||||
augroup END
|
||||
|
||||
silent! call llama#fim_cancel()
|
||||
|
||||
" init background update of the ring buffer
|
||||
if g:llama_config.ring_n_chunks > 0
|
||||
call s:ring_update()
|
||||
endif
|
||||
endfunction
|
||||
|
||||
" compute how similar two chunks of text are
|
||||
" 0 - no similarity, 1 - high similarity
|
||||
" TODO: figure out something better
|
||||
function! s:chunk_sim(c0, c1)
|
||||
let l:lines0 = len(a:c0)
|
||||
let l:lines1 = len(a:c1)
|
||||
|
||||
let l:common = 0
|
||||
|
||||
for l:line0 in a:c0
|
||||
for l:line1 in a:c1
|
||||
if l:line0 == l:line1
|
||||
let l:common += 1
|
||||
break
|
||||
endif
|
||||
endfor
|
||||
endfor
|
||||
|
||||
return 2.0 * l:common / (l:lines0 + l:lines1)
|
||||
endfunction
|
||||
|
||||
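s:chunk_sim above scores two chunks by counting exactly matching lines and scaling by the combined length, a Dice-style coefficient that the plugin later compares against a 0.9 eviction threshold. An equivalent sketch in C++ for readers less familiar with Vimscript:

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Same idea as s:chunk_sim: 2 * (number of lines of c0 that also appear in c1)
// divided by the total number of lines. 0 means no overlap, 1 means identical.
static double chunk_sim(const std::vector<std::string> & c0, const std::vector<std::string> & c1) {
    int common = 0;
    for (const auto & l0 : c0) {
        for (const auto & l1 : c1) {
            if (l0 == l1) {
                common++;
                break;
            }
        }
    }
    if (c0.empty() && c1.empty()) {
        return 1.0; // both empty: treat as identical (the plugin never hits this case)
    }
    return 2.0 * common / (double) (c0.size() + c1.size());
}

int main() {
    std::vector<std::string> a = {"int x = 0;", "x++;"};
    std::vector<std::string> b = {"int x = 0;", "x--;"};
    std::printf("%.2f\n", chunk_sim(a, b)); // 0.50
    return 0;
}
```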
" pick a random chunk of size g:llama_config.ring_chunk_size from the provided text and queue it for processing
|
||||
"
|
||||
" no_mod - do not pick chunks from buffers with pending changes
|
||||
" do_evict - evict chunks that are very similar to the new one
|
||||
"
|
||||
function! s:pick_chunk(text, no_mod, do_evict)
|
||||
" do not pick chunks from buffers with pending changes or buffers that are not files
|
||||
if a:no_mod && (getbufvar(bufnr('%'), '&modified') || !buflisted(bufnr('%')) || !filereadable(expand('%')))
|
||||
return
|
||||
endif
|
||||
|
||||
" if the extra context option is disabled - do nothing
|
||||
if g:llama_config.ring_n_chunks <= 0
|
||||
return
|
||||
endif
|
||||
|
||||
" don't pick very small chunks
|
||||
if len(a:text) < 3
|
||||
return
|
||||
endif
|
||||
|
||||
if len(a:text) + 1 < g:llama_config.ring_chunk_size
|
||||
let l:chunk = a:text
|
||||
else
|
||||
let l:l0 = s:rand(0, max([0, len(a:text) - g:llama_config.ring_chunk_size/2]))
|
||||
let l:l1 = min([l:l0 + g:llama_config.ring_chunk_size/2, len(a:text)])
|
||||
|
||||
let l:chunk = a:text[l:l0:l:l1]
|
||||
endif
|
||||
|
||||
let l:chunk_str = join(l:chunk, "\n") . "\n"
|
||||
|
||||
" check if this chunk is already added
|
||||
let l:exist = v:false
|
||||
|
||||
for i in range(len(s:ring_chunks))
|
||||
if s:ring_chunks[i].data == l:chunk
|
||||
let l:exist = v:true
|
||||
break
|
||||
endif
|
||||
endfor
|
||||
|
||||
for i in range(len(s:ring_queued))
|
||||
if s:ring_queued[i].data == l:chunk
|
||||
let l:exist = v:true
|
||||
break
|
||||
endif
|
||||
endfor
|
||||
|
||||
if l:exist
|
||||
return
|
||||
endif
|
||||
|
||||
" evict queued chunks that are very similar to the new one
|
||||
for i in range(len(s:ring_queued) - 1, 0, -1)
|
||||
if s:chunk_sim(s:ring_queued[i].data, l:chunk) > 0.9
|
||||
if a:do_evict
|
||||
call remove(s:ring_queued, i)
|
||||
let s:ring_n_evict += 1
|
||||
else
|
||||
return
|
||||
endif
|
||||
endif
|
||||
endfor
|
||||
|
||||
" also from s:ring_chunks
|
||||
for i in range(len(s:ring_chunks) - 1, 0, -1)
|
||||
if s:chunk_sim(s:ring_chunks[i].data, l:chunk) > 0.9
|
||||
if a:do_evict
|
||||
call remove(s:ring_chunks, i)
|
||||
let s:ring_n_evict += 1
|
||||
else
|
||||
return
|
||||
endif
|
||||
endif
|
||||
endfor
|
||||
|
||||
" TODO: become parameter ?
|
||||
if len(s:ring_queued) == 16
|
||||
call remove(s:ring_queued, 0)
|
||||
endif
|
||||
|
||||
call add(s:ring_queued, {'data': l:chunk, 'str': l:chunk_str, 'time': reltime(), 'filename': expand('%')})
|
||||
|
||||
"let &statusline = 'extra context: ' . len(s:ring_chunks) . ' / ' . len(s:ring_queued)
|
||||
endfunction
|
||||
|
||||
" picks a queued chunk, sends it for processing and adds it to s:ring_chunks
|
||||
" called every g:llama_config.ring_update_ms
|
||||
function! s:ring_update()
|
||||
call timer_start(g:llama_config.ring_update_ms, {-> s:ring_update()})
|
||||
|
||||
" update only if in normal mode or if the cursor hasn't moved for a while
|
||||
if mode() !=# 'n' && reltimefloat(reltime(s:t_last_move)) < 3.0
|
||||
return
|
||||
endif
|
||||
|
||||
if len(s:ring_queued) == 0
|
||||
return
|
||||
endif
|
||||
|
||||
" move the first queued chunk to the ring buffer
|
||||
if len(s:ring_chunks) == g:llama_config.ring_n_chunks
|
||||
call remove(s:ring_chunks, 0)
|
||||
endif
|
||||
|
||||
call add(s:ring_chunks, remove(s:ring_queued, 0))
|
||||
|
||||
"let &statusline = 'updated context: ' . len(s:ring_chunks) . ' / ' . len(s:ring_queued)
|
||||
|
||||
" send asynchronous job with the new extra context so that it is ready for the next FIM
|
||||
let l:extra_context = []
|
||||
for l:chunk in s:ring_chunks
|
||||
call add(l:extra_context, {
|
||||
\ 'text': l:chunk.str,
|
||||
\ 'time': l:chunk.time,
|
||||
\ 'filename': l:chunk.filename
|
||||
\ })
|
||||
endfor
|
||||
|
||||
" no samplers needed here
|
||||
let l:request = json_encode({
|
||||
\ 'input_prefix': "",
|
||||
\ 'input_suffix': "",
|
||||
\ 'input_extra': l:extra_context,
|
||||
\ 'prompt': "",
|
||||
\ 'n_predict': 1,
|
||||
\ 'temperature': 0.0,
|
||||
\ 'stream': v:false,
|
||||
\ 'samplers': ["temperature"],
|
||||
\ 'cache_prompt': v:true,
|
||||
\ 't_max_prompt_ms': 1,
|
||||
\ 't_max_predict_ms': 1
|
||||
\ })
|
||||
|
||||
let l:curl_command = printf(
|
||||
\ "curl --silent --no-buffer --request POST --url %s --header \"Content-Type: application/json\" --data %s",
|
||||
\ g:llama_config.endpoint, shellescape(l:request)
|
||||
\ )
|
||||
|
||||
" no callbacks because we don't need to process the response
|
||||
call jobstart(l:curl_command, {})
|
||||
endfunction
|
||||
|
||||
" necessary for 'inoremap <expr>'
|
||||
function! llama#fim_inline(is_auto) abort
|
||||
call llama#fim(a:is_auto)
|
||||
return ''
|
||||
endfunction
|
||||
|
||||
" the main FIM call
|
||||
" takes local context around the cursor and sends it together with the extra context to the server for completion
|
||||
function! llama#fim(is_auto) abort
|
||||
" we already have a suggestion for the current cursor position
|
||||
if s:hint_shown && !a:is_auto
|
||||
call llama#fim_cancel()
|
||||
return
|
||||
endif
|
||||
|
||||
call llama#fim_cancel()
|
||||
|
||||
" avoid sending repeated requests too fast
|
||||
if reltimefloat(reltime(s:t_fim_start)) < 0.6
|
||||
if s:timer_fim != -1
|
||||
call timer_stop(s:timer_fim)
|
||||
let s:timer_fim = -1
|
||||
endif
|
||||
|
||||
let s:t_fim_start = reltime()
|
||||
let s:timer_fim = timer_start(600, {-> llama#fim(v:true)})
|
||||
return
|
||||
endif
|
||||
|
||||
let s:t_fim_start = reltime()
|
||||
|
||||
let s:content = []
|
||||
let s:can_accept = v:false
|
||||
|
||||
let s:pos_x = col('.') - 1
|
||||
let s:pos_y = line('.')
|
||||
let l:max_y = line('$')
|
||||
|
||||
let l:lines_prefix = getline(max([1, s:pos_y - g:llama_config.n_prefix]), s:pos_y - 1)
|
||||
let l:lines_suffix = getline(s:pos_y + 1, min([l:max_y, s:pos_y + g:llama_config.n_suffix]))
|
||||
|
||||
let s:line_cur = getline('.')
|
||||
|
||||
let s:line_cur_prefix = strpart(s:line_cur, 0, s:pos_x)
|
||||
let s:line_cur_suffix = strpart(s:line_cur, s:pos_x)
|
||||
|
||||
if a:is_auto && len(s:line_cur_suffix) > g:llama_config.max_line_suffix
|
||||
return
|
||||
endif
|
||||
|
||||
let l:prefix = ""
|
||||
\ . join(l:lines_prefix, "\n")
|
||||
\ . "\n"
|
||||
|
||||
let l:prompt = ""
|
||||
\ . s:line_cur_prefix
|
||||
|
||||
let l:suffix = ""
|
||||
\ . s:line_cur_suffix
|
||||
\ . "\n"
|
||||
\ . join(l:lines_suffix, "\n")
|
||||
\ . "\n"
|
||||
|
||||
" prepare the extra context data
|
||||
let l:extra_context = []
|
||||
for l:chunk in s:ring_chunks
|
||||
call add(l:extra_context, {
|
||||
\ 'text': l:chunk.str,
|
||||
\ 'time': l:chunk.time,
|
||||
\ 'filename': l:chunk.filename
|
||||
\ })
|
||||
endfor
|
||||
|
||||
" the indentation of the current line
|
||||
let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*'))
|
||||
|
||||
let l:request = json_encode({
|
||||
\ 'input_prefix': l:prefix,
|
||||
\ 'input_suffix': l:suffix,
|
||||
\ 'input_extra': l:extra_context,
|
||||
\ 'prompt': l:prompt,
|
||||
\ 'n_predict': g:llama_config.n_predict,
|
||||
\ 'n_indent': l:indent,
|
||||
\ 'top_k': 40,
|
||||
\ 'top_p': 0.99,
|
||||
\ 'stream': v:false,
|
||||
\ 'samplers': ["top_k", "top_p", "infill"],
|
||||
\ 'cache_prompt': v:true,
|
||||
\ 't_max_prompt_ms': g:llama_config.t_max_prompt_ms,
|
||||
\ 't_max_predict_ms': g:llama_config.t_max_predict_ms
|
||||
\ })
|
||||
|
||||
let l:curl_command = printf(
|
||||
\ "curl --silent --no-buffer --request POST --url %s --header \"Content-Type: application/json\" --data %s",
|
||||
\ g:llama_config.endpoint, shellescape(l:request)
|
||||
\ )
|
||||
|
||||
if s:current_job != v:null
|
||||
call jobstop(s:current_job)
|
||||
endif
|
||||
|
||||
" send the request asynchronously
|
||||
let s:current_job = jobstart(l:curl_command, {
|
||||
\ 'on_stdout': function('s:fim_on_stdout'),
|
||||
\ 'on_exit': function('s:fim_on_exit'),
|
||||
\ 'stdout_buffered': v:true,
|
||||
\ 'pos_x': s:pos_x,
|
||||
\ 'pos_y': s:pos_y,
|
||||
\ 'is_auto': a:is_auto
|
||||
\ })
|
||||
|
||||
" TODO: per-file location
|
||||
let l:delta_y = abs(s:pos_y - s:pos_y_pick)
|
||||
|
||||
" gather some extra context nearby and process it in the background
|
||||
" only gather chunks if the cursor has moved a lot
|
||||
" TODO: something more clever? reranking?
|
||||
if a:is_auto && l:delta_y > 32
|
||||
" expand the prefix even further
|
||||
call s:pick_chunk(getline(max([1, s:pos_y - g:llama_config.ring_scope]), max([1, s:pos_y - g:llama_config.n_prefix])), v:false, v:false)
|
||||
|
||||
" pick a suffix chunk
|
||||
call s:pick_chunk(getline(min([l:max_y, s:pos_y + g:llama_config.n_suffix]), min([l:max_y, s:pos_y + g:llama_config.n_suffix + g:llama_config.ring_chunk_size])), v:false, v:false)
|
||||
|
||||
let s:pos_y_pick = s:pos_y
|
||||
endif
|
||||
endfunction
|
||||
|
||||
" if first_line == v:true accept only the first line of the response
|
||||
function! llama#fim_accept(first_line)
|
||||
" insert the suggestion at the cursor location
|
||||
if s:can_accept && len(s:content) > 0
|
||||
call setline(s:pos_y, s:line_cur[:(s:pos_x - 1)] . s:content[0])
|
||||
if len(s:content) > 1
|
||||
if !a:first_line
|
||||
call append(s:pos_y, s:content[1:-1])
|
||||
endif
|
||||
endif
|
||||
|
||||
" move the cursor to the end of the accepted text
|
||||
if !a:first_line && len(s:content) > 1
|
||||
call cursor(s:pos_y + len(s:content) - 1, s:pos_x + s:pos_dx + 1)
|
||||
else
|
||||
call cursor(s:pos_y, s:pos_x + len(s:content[0]))
|
||||
endif
|
||||
endif
|
||||
|
||||
call llama#fim_cancel()
|
||||
endfunction
|
||||
|
||||
function! llama#fim_cancel()
|
||||
let s:hint_shown = v:false
|
||||
|
||||
" clear the virtual text
|
||||
let l:bufnr = bufnr('%')
|
||||
|
||||
let l:id_vt_fim = nvim_create_namespace('vt_fim')
|
||||
|
||||
call nvim_buf_clear_namespace(l:bufnr, l:id_vt_fim, 0, -1)
|
||||
|
||||
" remove the mappings
|
||||
silent! iunmap <buffer> <Tab>
|
||||
silent! iunmap <buffer> <S-Tab>
|
||||
silent! iunmap <buffer> <Esc>
|
||||
endfunction
|
||||
|
||||
function! s:on_move()
|
||||
let s:t_last_move = reltime()
|
||||
|
||||
call llama#fim_cancel()
|
||||
endfunction
|
||||
|
||||
" callback that processes the FIM result from the server and displays the suggestion
|
||||
function! s:fim_on_stdout(job_id, data, event) dict
|
||||
let l:raw = join(a:data, "\n")
|
||||
if len(l:raw) == 0
|
||||
return
|
||||
endif
|
||||
|
||||
if self.pos_x != col('.') - 1 || self.pos_y != line('.')
|
||||
return
|
||||
endif
|
||||
|
||||
" show the suggestion only in insert mode
|
||||
if mode() !=# 'i'
|
||||
return
|
||||
endif
|
||||
|
||||
let s:pos_x = self.pos_x
|
||||
let s:pos_y = self.pos_y
|
||||
|
||||
let s:can_accept = v:true
|
||||
let l:has_info = v:false
|
||||
|
||||
if s:can_accept && v:shell_error
|
||||
if !self.is_auto
|
||||
call add(s:content, "<| curl error: is the server on? |>")
|
||||
endif
|
||||
let s:can_accept = v:false
|
||||
endif
|
||||
|
||||
let l:n_prompt = 0
|
||||
let l:t_prompt_ms = 1.0
|
||||
let l:s_prompt = 0
|
||||
|
||||
let l:n_predict = 0
|
||||
let l:t_predict_ms = 1.0
|
||||
let l:s_predict = 0
|
||||
|
||||
" get the generated suggestion
|
||||
if s:can_accept
|
||||
let l:response = json_decode(l:raw)
|
||||
|
||||
for l:part in split(get(l:response, 'content', ''), "\n", 1)
|
||||
call add(s:content, l:part)
|
||||
endfor
|
||||
|
||||
" remove trailing new lines
|
||||
while len(s:content) > 0 && s:content[-1] == ""
|
||||
call remove(s:content, -1)
|
||||
endwhile
|
||||
|
||||
let l:generation_settings = get(l:response, 'generation_settings', {})
|
||||
let l:n_ctx = get(l:generation_settings, 'n_ctx', 0)
|
||||
|
||||
let l:n_cached = get(l:response, 'tokens_cached', 0)
|
||||
let l:truncated = get(l:response, 'truncated', v:false)
|
||||
|
||||
" if response.timings is available
|
||||
if len(get(l:response, 'timings', {})) > 0
|
||||
let l:has_info = v:true
|
||||
let l:timings = get(l:response, 'timings', {})
|
||||
|
||||
let l:n_prompt = get(l:timings, 'prompt_n', 0)
|
||||
let l:t_prompt_ms = get(l:timings, 'prompt_ms', 1)
|
||||
let l:s_prompt = get(l:timings, 'prompt_per_second', 0)
|
||||
|
||||
let l:n_predict = get(l:timings, 'predicted_n', 0)
|
||||
let l:t_predict_ms = get(l:timings, 'predicted_ms', 1)
|
||||
let l:s_predict = get(l:timings, 'predicted_per_second', 0)
|
||||
endif
|
||||
endif
|
||||
|
||||
if len(s:content) == 0
|
||||
call add(s:content, "")
|
||||
let s:can_accept = v:false
|
||||
endif
|
||||
|
||||
if len(s:content) == 0
|
||||
return
|
||||
endif
|
||||
|
||||
" NOTE: the following is logic for discarding predictions that repeat existing text
|
||||
" the code is quite ugly and there is very likely a simpler and more canonical way to implement this
|
||||
"
|
||||
" still, I wonder if there is some better way that avoids having to do these special hacks?
|
||||
" on one hand, the LLM 'sees' the contents of the file before we start editing, so it is normal that it would
|
||||
" start generating whatever we have given it via the extra context. but on the other hand, it's not very
|
||||
" helpful to re-generate the same code that is already there
|
||||
|
||||
" truncate the suggestion if the first line is empty
|
||||
if len(s:content) == 1 && s:content[0] == ""
|
||||
let s:content = [""]
|
||||
endif
|
||||
|
||||
" ... and the next lines are repeated
|
||||
if len(s:content) > 1 && s:content[0] == "" && s:content[1:] == getline(s:pos_y + 1, s:pos_y + len(s:content) - 1)
|
||||
let s:content = [""]
|
||||
endif
|
||||
|
||||
" truncate the suggestion if it repeats the suffix
|
||||
if len(s:content) == 1 && s:content[0] == s:line_cur_suffix
|
||||
let s:content = [""]
|
||||
endif
|
||||
|
||||
" find the first non-empty line (strip whitespace)
|
||||
let l:cmp_y = s:pos_y + 1
|
||||
while l:cmp_y < line('$') && getline(l:cmp_y) =~? '^\s*$'
|
||||
let l:cmp_y += 1
|
||||
endwhile
|
||||
|
||||
if (s:line_cur_prefix . s:content[0]) == getline(l:cmp_y)
|
||||
" truncate the suggestion if it repeats the next line
|
||||
if len(s:content) == 1
|
||||
let s:content = [""]
|
||||
endif
|
||||
|
||||
" ... or if the second line of the suggestion is the prefix of line l:cmp_y + 1
|
||||
if len(s:content) == 2 && s:content[-1] == getline(l:cmp_y + 1)[:len(s:content[-1]) - 1]
|
||||
let s:content = [""]
|
||||
endif
|
||||
|
||||
" ... or if the middle chunk of lines of the suggestion is the same as [l:cmp_y + 1, l:cmp_y + len(s:content) - 1)
|
||||
if len(s:content) > 2 && join(s:content[1:-1], "\n") == join(getline(l:cmp_y + 1, l:cmp_y + len(s:content) - 1), "\n")
|
||||
let s:content = [""]
|
||||
endif
|
||||
endif
|
||||
|
||||
" keep only lines that have the same or larger whitespace prefix as s:line_cur_prefix
|
||||
"let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*'))
|
||||
"for i in range(1, len(s:content) - 1)
|
||||
" if strlen(matchstr(s:content[i], '^\s*')) < l:indent
|
||||
" let s:content = s:content[:i - 1]
|
||||
" break
|
||||
" endif
|
||||
"endfor
|
||||
|
||||
let s:pos_dx = len(s:content[-1])
|
||||
|
||||
let s:content[-1] .= s:line_cur_suffix
|
||||
|
||||
call llama#fim_cancel()
|
||||
|
||||
" display virtual text with the suggestion
|
||||
let l:bufnr = bufnr('%')
|
||||
|
||||
let l:id_vt_fim = nvim_create_namespace('vt_fim')
|
||||
|
||||
" construct the info message
|
||||
if g:llama_config.show_info > 0 && l:has_info
|
||||
let l:prefix = ' '
|
||||
|
||||
if l:truncated
|
||||
let l:info = printf("%s | WARNING: the context is full: %d / %d, increase the server context size or reduce g:llama_config.ring_n_chunks",
|
||||
\ g:llama_config.show_info == 2 ? l:prefix : 'llama.vim',
|
||||
\ l:n_cached, l:n_ctx
|
||||
\ )
|
||||
else
|
||||
let l:info = printf("%s | c: %d / %d, r: %d / %d, e: %d, q: %d / 16 | p: %d (%.2f ms, %.2f t/s) | g: %d (%.2f ms, %.2f t/s) | t: %.2f ms",
|
||||
\ g:llama_config.show_info == 2 ? l:prefix : 'llama.vim',
|
||||
\ l:n_cached, l:n_ctx, len(s:ring_chunks), g:llama_config.ring_n_chunks, s:ring_n_evict, len(s:ring_queued),
|
||||
\ l:n_prompt, l:t_prompt_ms, l:s_prompt,
|
||||
\ l:n_predict, l:t_predict_ms, l:s_predict,
|
||||
\ 1000.0 * reltimefloat(reltime(s:t_fim_start))
|
||||
\ )
|
||||
endif
|
||||
|
||||
if g:llama_config.show_info == 1
|
||||
" display the info in the statusline
|
||||
let &statusline = l:info
|
||||
let l:info = ''
|
||||
endif
|
||||
endif
|
||||
|
||||
" display the suggestion and append the info to the end of the first line
|
||||
call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, s:pos_x - 1, {
|
||||
\ 'virt_text': [[s:content[0], 'llama_hl_hint'], [l:info, 'llama_hl_info']],
|
||||
\ 'virt_text_win_col': virtcol('.') - 1
|
||||
\ })
|
||||
|
||||
call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, 0, {
|
||||
\ 'virt_lines': map(s:content[1:], {idx, val -> [[val, 'llama_hl_hint']]}),
|
||||
\ 'virt_text_win_col': virtcol('.')
|
||||
\ })
|
||||
|
||||
" setup accept shortcuts
|
||||
inoremap <buffer> <Tab> <C-O>:call llama#fim_accept(v:false)<CR>
|
||||
inoremap <buffer> <S-Tab> <C-O>:call llama#fim_accept(v:true)<CR>
|
||||
|
||||
let s:hint_shown = v:true
|
||||
endfunction
|
||||
|
||||
function! s:fim_on_exit(job_id, exit_code, event) dict
|
||||
if a:exit_code != 0
|
||||
echom "Job failed with exit code: " . a:exit_code
|
||||
endif
|
||||
|
||||
let s:current_job = v:null
|
||||
endfunction
|
||||
@@ -34,6 +34,8 @@ extern "C" {
*/
#define GGML_CANN_MAX_DEVICES 16

GGML_API ggml_backend_reg_t ggml_backend_cann_reg(void);

/**
* @brief Initializes the CANN backend for a specified device.
*
@@ -561,6 +561,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
# include "ggml-amx.h"
#endif

#ifdef GGML_USE_CANN
#include "ggml-cann.h"
#endif

struct ggml_backend_registry {
std::vector<ggml_backend_reg_t> backends;
std::vector<ggml_backend_dev_t> devices;
@@ -587,8 +591,11 @@ struct ggml_backend_registry {
#ifdef GGML_USE_AMX
register_backend(ggml_backend_amx_reg());
#endif
#ifdef GGML_USE_CANN
register_backend(ggml_backend_cann_reg());
#endif

// TODO: kompute, cann
// TODO: kompute

register_backend(ggml_backend_cpu_reg());
}
@@ -39,6 +39,8 @@
|
||||
|
||||
#include "ggml-common.h"
|
||||
|
||||
#define GGML_CANN_NAME "CANN"
|
||||
|
||||
/**
|
||||
* @brief Handles CANN errors by printing an error message and aborting.
|
||||
*
|
||||
@@ -851,13 +853,6 @@ static void ggml_backend_cann_buffer_set_tensor(
|
||||
void *transform_buffer = malloc(size);
|
||||
ggml_backend_cann_transform(tensor, data, transform_buffer);
|
||||
|
||||
#ifndef NDEBUG
|
||||
void *check_buffer = malloc(size);
|
||||
ggml_backend_cann_transform_back(tensor, transform_buffer,
|
||||
check_buffer);
|
||||
GGML_ASSERT(memcmp(data, check_buffer, size) == 0);
|
||||
free(check_buffer);
|
||||
#endif
|
||||
ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
|
||||
transform_buffer, size,
|
||||
ACL_MEMCPY_HOST_TO_DEVICE));
|
||||
@@ -969,7 +964,7 @@ static void ggml_backend_cann_buffer_clear(
|
||||
* This structure defines function pointers to operations that can be performed
|
||||
* on a CANN buffer within the backend.
|
||||
*/
|
||||
static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
|
||||
static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
|
||||
/* .get_name = */ ggml_backend_cann_buffer_get_name,
|
||||
/* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
|
||||
/* .get_base = */ ggml_backend_cann_buffer_get_base,
|
||||
@@ -1105,19 +1100,25 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
||||
return false;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Interface for managing CANN buffer types in the GGML backend.
|
||||
*
|
||||
* Provides function pointers for allocating, querying properties, and managing
|
||||
* memory for CANN buffer types in the GGML backend.
|
||||
*/
|
||||
static ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
|
||||
static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
|
||||
/* .get_name = */ ggml_backend_cann_buffer_type_name,
|
||||
/* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment,
|
||||
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
|
||||
/* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
|
||||
/* .is_host = */ NULL,
|
||||
/* .is_host = */ ggml_backend_cann_buffer_type_is_host,
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -1148,7 +1149,7 @@ ggml_backend_cann_buffer_type(int32_t device) {
|
||||
for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
|
||||
ggml_backend_cann_buffer_types[i] = {
|
||||
/* .iface = */ ggml_backend_cann_buffer_type_interface,
|
||||
/* .device = */ nullptr,
|
||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
|
||||
/* .context = */
|
||||
new ggml_backend_cann_buffer_type_context{
|
||||
i, "CANN" + std::to_string(i)},
|
||||
@@ -1264,7 +1265,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
|
||||
/* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
|
||||
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
|
||||
},
|
||||
/* .device = */ nullptr,
|
||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
|
||||
/* .context = */ nullptr,
|
||||
};
|
||||
|
||||
@@ -1511,13 +1512,6 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
|
||||
void *transform_buffer = malloc(size);
|
||||
ggml_backend_cann_transform(tensor, data, transform_buffer);
|
||||
|
||||
#ifndef NDEBUG
|
||||
void *check_buffer = malloc(size);
|
||||
ggml_backend_cann_transform_back(tensor, transform_buffer,
|
||||
check_buffer);
|
||||
GGML_ASSERT(memcmp(data, check_buffer, size));
|
||||
free(check_buffer);
|
||||
#endif
|
||||
ACL_CHECK(aclrtMemcpyAsync(
|
||||
(char *)tensor->data + offset, size, transform_buffer, size,
|
||||
ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
|
||||
@@ -1692,7 +1686,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
|
||||
* @return bool Returns true if the operation is supported by the backend,
|
||||
* otherwise false.
|
||||
*/
|
||||
static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
|
||||
static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
const ggml_tensor* op) {
|
||||
switch (op->op) {
|
||||
case GGML_OP_UNARY:
|
||||
@@ -1783,7 +1777,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
|
||||
return false;
|
||||
}
|
||||
|
||||
GGML_UNUSED(backend);
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1801,31 +1795,6 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
|
||||
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Checks if the CANN backend supports a specific backend buffer type.
|
||||
*
|
||||
* This function determines whether the CANN backend supports the given backend
|
||||
* buffer type by comparing the device context of the backend and buffer type.
|
||||
* It returns true if the devices are same between the backend context and
|
||||
* buffer type context.
|
||||
*
|
||||
* @param backend Pointer to the CANN backend.
|
||||
* @param buft Pointer to the backend buffer type to check.
|
||||
* @return bool Returns true if the CANN backend supports the buffer type,
|
||||
* otherwise false.
|
||||
*/
|
||||
static bool ggml_backend_cann_supports_buft(
|
||||
ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
|
||||
if (ggml_backend_buft_is_cann(buft)) {
|
||||
ggml_backend_cann_context * cann_ctx =
|
||||
(ggml_backend_cann_context *)backend->context;
|
||||
ggml_backend_cann_buffer_type_context * buft_ctx =
|
||||
(ggml_backend_cann_buffer_type_context *)buft->context;
|
||||
return buft_ctx->device == cann_ctx->device;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Determines if a tensor operation should be offloaded to the CANN
|
||||
* backend.
|
||||
@@ -1840,54 +1809,14 @@ static bool ggml_backend_cann_supports_buft(
|
||||
* @return bool Returns true if the operation should be offloaded, otherwise
|
||||
* false.
|
||||
*/
|
||||
static bool ggml_backend_cann_offload_op(ggml_backend_t backend,
|
||||
static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
|
||||
const ggml_tensor* op) {
|
||||
const int min_batch_size = 32;
|
||||
GGML_UNUSED(backend);
|
||||
GGML_UNUSED(dev);
|
||||
|
||||
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Creates a new event for the CANN backend.
|
||||
*
|
||||
* This function initializes a new event for the CANN backend by setting the
|
||||
* device and creating an ACL runtime event. The created event is then wrapped
|
||||
* in a ggml_backend_event structure and returned.
|
||||
*
|
||||
* @param backend Pointer to the CANN backend.
|
||||
* @return ggml_backend_event_t Returns a pointer to the new event structure.
|
||||
*/
|
||||
static ggml_backend_event_t ggml_backend_cann_event_new(
|
||||
ggml_backend_t backend) {
|
||||
ggml_backend_cann_context* cann_ctx =
|
||||
(ggml_backend_cann_context*)backend->context;
|
||||
|
||||
ggml_cann_set_device(cann_ctx->device);
|
||||
|
||||
aclrtEvent event;
|
||||
ACL_CHECK(aclrtCreateEvent(&event));
|
||||
|
||||
return new ggml_backend_event{
|
||||
/* .device = */ nullptr,
|
||||
/* .context = */ event,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Frees a CANN backend event.
|
||||
*
|
||||
* This function destroys the ACL runtime event associated with the given CANN
|
||||
* backend event and then deletes the event structure itself.
|
||||
*
|
||||
* @param event Pointer to the event structure to be freed.
|
||||
*/
|
||||
static void ggml_backend_cann_event_free(ggml_backend_event_t event) {
|
||||
ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
|
||||
|
||||
delete event;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Records an event on the CANN backend stream.
|
||||
*
|
||||
@@ -1924,17 +1853,6 @@ static void ggml_backend_cann_event_wait(ggml_backend_t backend,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Synchronizes the given event on the CANN backend.
|
||||
*
|
||||
* This function waits for the specified event to complete on the ACL runtime.
|
||||
*
|
||||
* @param event Pointer to the event structure to be synchronized.
|
||||
*/
|
||||
static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) {
|
||||
ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Structure defining the interface for the CANN backend.
|
||||
*
|
||||
@@ -1942,7 +1860,7 @@ static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) {
|
||||
* supported by the CANN backend, including name retrieval, memory
|
||||
* management, tensor operations, synchronization, and event handling.
|
||||
*/
|
||||
static ggml_backend_i ggml_backend_cann_interface = {
|
||||
static const ggml_backend_i ggml_backend_cann_interface = {
|
||||
/* .get_name = */ ggml_backend_cann_name,
|
||||
/* .free = */ ggml_backend_cann_free,
|
||||
/* .get_default_buffer_type = */ ggml_backend_cann_get_default_buffer_type,
|
||||
@@ -1955,9 +1873,9 @@ static ggml_backend_i ggml_backend_cann_interface = {
|
||||
/* .graph_plan_update = */ NULL,
|
||||
/* .graph_plan_compute = */ NULL,
|
||||
/* .graph_compute = */ ggml_backend_cann_graph_compute,
|
||||
/* .supports_op = */ ggml_backend_cann_supports_op,
|
||||
/* .supports_buft = */ ggml_backend_cann_supports_buft,
|
||||
/* .offload_op = */ ggml_backend_cann_offload_op,
|
||||
/* .supports_op = */ NULL, // moved to device
|
||||
/* .supports_buft = */ NULL, // moved to device
|
||||
/* .offload_op = */ NULL, // moved to device
|
||||
/* .event_record = */ ggml_backend_cann_event_record,
|
||||
/* .event_wait = */ ggml_backend_cann_event_wait,
|
||||
};
|
||||
@@ -1976,6 +1894,234 @@ static ggml_guid_t ggml_backend_cann_guid() {
|
||||
return &guid;
|
||||
}
|
||||
|
||||
// backend device
|
||||
struct ggml_backend_cann_device_context {
|
||||
int device;
|
||||
std::string name;
|
||||
std::string description;
|
||||
};
|
||||
|
||||
static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
|
||||
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
|
||||
return ctx->name.c_str();
|
||||
}
|
||||
|
||||
static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
|
||||
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
|
||||
return ctx->description.c_str();
|
||||
}
|
||||
|
||||
static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
|
||||
ggml_backend_cann_get_device_memory(ctx->device, free, total);
|
||||
}
|
||||
|
||||
static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
|
||||
GGML_UNUSED(dev);
|
||||
return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
|
||||
}
|
||||
|
||||
static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||
props->name = ggml_backend_cann_device_get_name(dev);
|
||||
props->description = ggml_backend_cann_device_get_description(dev);
|
||||
props->type = ggml_backend_cann_device_get_type(dev);
|
||||
ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
|
||||
bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr;
|
||||
|
||||
props->caps = {
|
||||
/* .async = */ false,
|
||||
/* .host_buffer = */ host_buffer,
|
||||
/* .buffer_from_host_ptr = */ false,
|
||||
/* .events = */ true,
|
||||
};
|
||||
}
|
||||
|
||||
static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||
GGML_UNUSED(params);
|
||||
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
|
||||
return ggml_backend_cann_init(ctx->device);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Checks if the CANN backend supports a specific backend buffer type.
|
||||
*
|
||||
* This function determines whether the CANN backend supports the given backend
|
||||
* buffer type by comparing the device context of the backend and buffer type.
|
||||
* It returns true if the devices are same between the backend context and
|
||||
* buffer type context.
|
||||
*
|
||||
* @param backend Pointer to the CANN backend.
|
||||
* @param buft Pointer to the backend buffer type to check.
|
||||
* @return bool Returns true if the CANN backend supports the buffer type,
|
||||
* otherwise false.
|
||||
*/
|
||||
static bool ggml_backend_cann_supports_buft(
|
||||
ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
||||
if (ggml_backend_buft_is_cann(buft)) {
|
||||
ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
|
||||
ggml_backend_cann_buffer_type_context * buft_ctx =
|
||||
(ggml_backend_cann_buffer_type_context *)buft->context;
|
||||
return buft_ctx->device == dev_ctx->device;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
|
||||
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
|
||||
return ggml_backend_cann_buffer_type(ctx->device);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
|
||||
GGML_UNUSED(dev);
|
||||
return ggml_backend_cann_host_buffer_type();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Creates a new event for the CANN backend device.
|
||||
*
|
||||
* This function initializes a new event for the CANN backend by setting the
|
||||
* device and creating an ACL runtime event. The created event is then wrapped
|
||||
* in a ggml_backend_event structure and returned.
|
||||
*
|
||||
* @param backend Pointer to the CANN backend.
|
||||
* @return ggml_backend_event_t Returns a pointer to the new event structure.
|
||||
*/
|
||||
static ggml_backend_event_t ggml_backend_cann_device_event_new(
|
||||
ggml_backend_dev_t dev) {
|
||||
ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
|
||||
|
||||
ggml_cann_set_device(dev_ctx->device);
|
||||
|
||||
aclrtEvent event;
|
||||
ACL_CHECK(aclrtCreateEvent(&event));
|
||||
|
||||
return new ggml_backend_event{
|
||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
|
||||
/* .context = */ event,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Frees a CANN backend event.
|
||||
*
|
||||
* This function destroys the ACL runtime event associated with the given CANN
|
||||
* backend event and then deletes the event structure itself.
|
||||
*
|
||||
* @param event Pointer to the event structure to be freed.
|
||||
*/
|
||||
static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
|
||||
ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
|
||||
|
||||
delete event;
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Synchronizes the given event on the CANN backend.
|
||||
*
|
||||
* This function waits for the specified event to complete on the ACL runtime.
|
||||
*
|
||||
* @param event Pointer to the event structure to be synchronized.
|
||||
*/
|
||||
static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
|
||||
ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static const ggml_backend_device_i ggml_backend_cann_device_interface = {
|
||||
/* .get_name = */ ggml_backend_cann_device_get_name,
|
||||
/* .get_description = */ ggml_backend_cann_device_get_description,
|
||||
/* .get_memory = */ ggml_backend_cann_device_get_memory,
|
||||
/* .get_type = */ ggml_backend_cann_device_get_type,
|
||||
/* .get_props = */ ggml_backend_cann_device_get_props,
|
||||
/* .init_backend = */ ggml_backend_cann_device_init, // called for every card
|
||||
/* .get_buffer_type = */ ggml_backend_cann_device_get_buffer_type,
|
||||
/* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type,
|
||||
/* .buffer_from_host_ptr = */ NULL, // not supported for CANN
|
||||
/* .supports_op = */ ggml_backend_cann_supports_op,
|
||||
/* .supports_buft = */ ggml_backend_cann_supports_buft,
|
||||
/* .offload_op = */ ggml_backend_cann_offload_op,
|
||||
/* .event_new = */ ggml_backend_cann_device_event_new,
|
||||
/* .event_free = */ ggml_backend_cann_device_event_free,
|
||||
/* .event_synchronize = */ ggml_backend_cann_device_event_synchronize,
|
||||
};
|
||||
|
||||
|
||||
// backend reg
|
||||
struct ggml_backend_cann_reg_context {
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
};
|
||||
|
||||
static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
|
||||
GGML_UNUSED(reg);
|
||||
return GGML_CANN_NAME;
|
||||
}
|
||||
|
||||
static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
|
||||
ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
|
||||
return ctx->devices.size();
|
||||
}
|
||||
|
||||
static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
||||
ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
|
||||
GGML_ASSERT(index < ctx->devices.size());
|
||||
return ctx->devices[index];
|
||||
}
|
||||
|
||||
static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
|
||||
GGML_UNUSED(reg);
|
||||
GGML_UNUSED(name);
|
||||
// reserved for future use
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
|
||||
/* .get_name = */ ggml_backend_cann_reg_get_name,
|
||||
/* .get_device_count = */ ggml_backend_cann_reg_get_device_count,
|
||||
/* .get_device_get = */ ggml_backend_cann_reg_get_device,
|
||||
/* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address,
|
||||
};
|
||||
|
||||
// backend registry, called only once for cann backend
|
||||
ggml_backend_reg_t ggml_backend_cann_reg() {
|
||||
static ggml_backend_reg reg;
|
||||
static bool initialized = false;
|
||||
|
||||
{
|
||||
static std::mutex mutex;
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
if (!initialized) {
|
||||
aclInit(nullptr);
|
||||
ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
|
||||
|
||||
for (int i = 0; i < ggml_cann_info().device_count; i++) {
|
||||
ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
|
||||
dev_ctx->description = aclrtGetSocName();
|
||||
dev_ctx->device = i;
|
||||
dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
|
||||
ggml_cann_set_device(i);
|
||||
ggml_backend_dev_t dev = new ggml_backend_device {
|
||||
/* .interface = */ ggml_backend_cann_device_interface,
|
||||
/* .reg = */ ®,
|
||||
/* .context = */ dev_ctx
|
||||
};
|
||||
ctx->devices.push_back(dev);
|
||||
}
|
||||
|
||||
reg = ggml_backend_reg {
|
||||
/* .interface = */ ggml_backend_cann_reg_interface,
|
||||
/* .context = */ ctx
|
||||
};
|
||||
}
|
||||
|
||||
initialized = true;
|
||||
}
|
||||
|
||||
return ®
|
||||
}
|
||||
|
||||
ggml_backend_t ggml_backend_cann_init(int32_t device) {
|
||||
aclInit(nullptr);
|
||||
if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
|
||||
@@ -1992,7 +2138,7 @@ ggml_backend_t ggml_backend_cann_init(int32_t device) {
|
||||
ggml_backend_t cann_backend =
|
||||
new ggml_backend{/* .guid = */ ggml_backend_cann_guid(),
|
||||
/* .interface = */ ggml_backend_cann_interface,
|
||||
/* .device = */ nullptr,
|
||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
|
||||
/* .context = */ ctx};
|
||||
|
||||
return cann_backend;
|
||||
|
||||
@@ -324,8 +324,9 @@ struct ggml_logger_state {
|
||||
static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};
|
||||
|
||||
static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
|
||||
if (format == NULL)
|
||||
if (format == NULL) {
|
||||
return;
|
||||
}
|
||||
va_list args_copy;
|
||||
va_copy(args_copy, args);
|
||||
char buffer[128];
|
||||
@@ -15723,6 +15724,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
|
||||
ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
|
||||
|
||||
GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
|
||||
GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
|
||||
|
||||
// loop over n_batch and n_head
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// q indices
|
||||
|
||||
@@ -10,8 +10,6 @@
|
||||
|
||||
#if defined(GGML_USE_KOMPUTE)
|
||||
# include "ggml-kompute.h"
|
||||
#elif defined(GGML_USE_CANN)
|
||||
# include "ggml-cann.h"
|
||||
#endif
|
||||
|
||||
#ifndef __AMX_INT8__
|
||||
@@ -3399,10 +3397,6 @@ static int llama_get_device_count(const llama_model & model) {
|
||||
count += (int) model.rpc_servers.size();
|
||||
#endif
|
||||
|
||||
#if defined(GGML_USE_CANN)
|
||||
count += ggml_backend_cann_get_device_count();
|
||||
#endif
|
||||
|
||||
return count;
|
||||
|
||||
GGML_UNUSED(model);
|
||||
@@ -3420,11 +3414,7 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(const llama_mode
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(GGML_USE_CANN)
|
||||
if (host_buffer) {
|
||||
buft = ggml_backend_cann_host_buffer_type();
|
||||
}
|
||||
#elif defined(GGML_USE_CPU_HBM)
|
||||
#if defined(GGML_USE_CPU_HBM)
|
||||
buft = ggml_backend_cpu_hbm_buffer_type();
|
||||
#endif
|
||||
|
||||
@@ -3446,8 +3436,6 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_
|
||||
|
||||
#if defined(GGML_USE_KOMPUTE)
|
||||
buft = ggml_backend_kompute_buffer_type(device);
|
||||
#elif defined(GGML_USE_CANN)
|
||||
buft = ggml_backend_cann_buffer_type(device);
|
||||
#endif
|
||||
|
||||
if (buft == nullptr) {
|
||||
@@ -3491,14 +3479,13 @@ static size_t llama_get_device_memory(const llama_model & model, int device) {
|
||||
return free;
|
||||
}
|
||||
|
||||
#if defined(GGML_USE_CANN)
|
||||
size_t total;
|
||||
size_t free;
|
||||
ggml_backend_cann_get_device_memory(device, &free, &total);
|
||||
return free;
|
||||
#else
|
||||
if (model.devices.size() > 0) {
|
||||
ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(model.devices[0]);
|
||||
LLAMA_LOG_WARN("%s: failed to get free memmory of device:%d of backend:%s, for device id is out of range.\n", __func__, device, ggml_backend_reg_name(reg));
|
||||
} else {
|
||||
LLAMA_LOG_WARN("%s: failed to get free memmory of device, no devices in inputted model.\n", __func__);
|
||||
}
|
||||
return 1;
|
||||
#endif
|
||||
|
||||
GGML_UNUSED(model);
|
||||
GGML_UNUSED(device);
|
||||
@@ -19243,7 +19230,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
params.flash_attn = false;
|
||||
}
|
||||
|
||||
if (params.type_v != GGML_TYPE_F16 && !params.flash_attn) {
|
||||
if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
|
||||
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -19396,30 +19383,6 @@ struct llama_context * llama_new_context_with_model(
|
||||
}
|
||||
ctx->backends.push_back(backend);
|
||||
}
|
||||
#elif defined(GGML_USE_CANN)
|
||||
// with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
|
||||
// TODO: ggml_backend_cann is not support split tensor now, just leave code here.
|
||||
if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
|
||||
ggml_backend_t backend = ggml_backend_cann_init(main_gpu);
|
||||
if (backend == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, main_gpu);
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
ctx->backends.push_back(backend);
|
||||
} else {
|
||||
// LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
|
||||
// TODO: currently, CANN can't use multi-gpus, just leave code here for further cann version.
|
||||
for (int32_t device = 0; device < ggml_backend_cann_get_device_count(); ++device) {
|
||||
ggml_backend_t backend = ggml_backend_cann_init(device);
|
||||
if (backend == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, device);
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
ctx->backends.push_back(backend);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// add other backends (such as BLAS)
|
||||
@@ -21734,6 +21697,16 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "[|assistant|]";
}
} else if (tmpl == "rwkv-world" || tmpl_contains("rwkv-world")) {
// this template requires the model to have "\n\n" as EOT token
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
ss << "User: " << message->content << "\n\nAssistant:";
} else {
ss << message->content << "\n\n";
}
}
} else {
// template not supported
return -1;