Compare commits

...

15 Commits

Author SHA1 Message Date
Meng, Hengyu
c5d8bb5a81 leave only basic functions for SYCL CI
2024-11-06 07:47:50 +00:00
Meng, Hengyu
c263ca767b remove wrong assert in norm
WA for permute(0,1,3,2) mul_mat
ggml-ci
2024-10-25 08:05:21 +00:00
Xuan Son Nguyen
958367bf53 server : refactor slot input data, move tokenizer to HTTP thread (#10023)
* server : refactor slot input data, move tokenizer to HTTP thread

* move prompt_tokens.empty() check

* fix incorrect if branch

* fix infinite generation loop

* bring back infill validation

* add infill test

* try fixing format_infill

* fix test

* remove redundant code

* rename completion to inference

* update docs

* use llama_tokens everywhere
2024-10-24 21:51:22 +02:00
Georgi Gerganov
40f2555797 ci : fix cmake flags for SYCL 2024-10-24 21:23:33 +03:00
Johannes Gäßler
167a515651 CUDA: fix insufficient buffer clearing for MMQ (#10032)
2024-10-24 14:40:23 +02:00
Johannes Gäßler
c39665f589 CUDA: fix MMQ for non-contiguous src0, add tests (#10021)
* CUDA: fix MMQ for non-contiguous src0, add tests

* revise test code
2024-10-24 11:09:36 +02:00
wwoodsTM
0a1c750c80 server : samplers accept the prompt correctly (#10019)
2024-10-23 22:27:51 +03:00
Georgi Gerganov
190a37d797 sync : ggml
2024-10-23 17:23:55 +03:00
Georgi Gerganov
2d3aba9ee8 llama.vim : bump generation time limit to 3s [no ci] 2024-10-23 17:16:56 +03:00
Johannes Gäßler
80273a306d CUDA: fix 1D im2col, add tests (ggml/993) 2024-10-23 16:50:02 +03:00
Daniel Bevenius
c19af0acb1 ggml : remove redundant set of contexts used field (ggml/978)
This commit removes the setting of the `used` field of the contexts in
the global state (g_state) in `ggml_init`.

The motivation for this change is that I believe that this additional
initialization might not be required after the changes in Commit
45fc4fed0b9fb5b1af4a8525cbebb95e11208732 ("sync : latest changes from
whisper.cpp"), which changed the initialization of the contexts field
from `{ 0 }` to `{ { 0 } }`:

```console
             g_state = (struct ggml_state) {
-                /*.contexts =*/ { 0 },
+                /*.contexts =*/ { { 0 } },
             };
```
My understanding is that the `{0}` initialization might not have
zero-initialized all the nested fields in every array element because of
compiler differences, and might have been the reason for having the
explicit setting of the `used` fields to false.
2024-10-23 16:50:02 +03:00
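As a hedged illustration of the initialization question raised in the commit message above, the following C++ sketch compares the two brace forms on stand-in types. `context_container`, `state`, and the helper names are hypothetical and are not ggml's definitions. ggml itself is written in C; the C standard likewise gives members without an explicit initializer the same value as objects with static storage duration, so both forms are expected to zero every element on a conforming compiler.

```cpp
#include <cassert>

// Hypothetical stand-ins for the types behind g_state (not ggml's definitions).
struct context_container {
    int    used;
    void * mem;
};

struct state {
    context_container contexts[4];
};

// True when every field of every array element is zero.
bool all_zero(const state & s) {
    for (const context_container & c : s.contexts) {
        if (c.used != 0 || c.mem != nullptr) {
            return false;
        }
    }
    return true;
}

// Builds one object per brace form and checks both.
bool brace_forms_agree() {
    state a = { 0 };      // braces elided: the 0 lands in contexts[0].used
    state b = { { 0 } };  // explicit braces around the contexts array
    return all_zero(a) && all_zero(b);
}

// Usage: both forms should report a fully zeroed state.
void check_brace_forms() {
    assert(brace_forms_agree());
}
```

Under these assumptions the two spellings differ only in how explicit the nesting is, which fits the commit's view that the separate reset of `used` is redundant.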
Michael Coppola
ac113a0fee llama.vim : add classic vim support (#9995)
* added classic vim support

* fixed ring update, removed blank line

* minor

* minor

* minor doc update

* removed uneeded var

* minor

* minor

* fixed job_start creating new scratch buffers

* fixed job_start creating new scratch buffers

* fixed ghost text indenting when expandtab is on

* removed unused code

* minor

* unified fim_on_exit

* minor

* vim ghost text rendering now uses pos_x and pos_y parameters

* renamed *_hlgroup to hlgroup_*

* renamed *_ghost_text to ghost_text_*, moved nvim/vim detection to llama#init()

* minor

---------

Co-authored-by: Michael Coppola <info@michaeljcoppola.com>
2024-10-23 14:09:26 +03:00
Jun Hee Yoo
4c9388fb96 metal : add POOL2D and fix IM2COL (#9943)
* add pool_2d

Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>

* fix im2col and add unittest for N>=1024

Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>

* add tests for N % 1024 != 0

Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>

* remove trailing whitespaces

Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>

* apply suggestions

Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>

* apply more optimization

- original IM2COL kernel + _ext with MIN()

Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>

* apply review: change kernel name of pool_2d

Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>

* apply review

Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>

* fix more formatting and enhance readability

Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>

---------

Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>
2024-10-23 13:33:45 +03:00
github-actions[bot]
873279b159 flake.lock: Update
Flake lock file updates:

• Updated input 'nixpkgs':
    'github:NixOS/nixpkgs/5633bcff0c6162b9e4b5f1264264611e950c8ec7?narHash=sha256-9UTxR8eukdg%2BXZeHgxW5hQA9fIKHsKCdOIUycTryeVw%3D' (2024-10-09)
  → 'github:NixOS/nixpkgs/4c2fcb090b1f3e5b47eaa7bd33913b574a11e0a0?narHash=sha256-/uilDXvCIEs3C9l73JTACm4quuHUsIHcns1c%2BcHUJwA%3D' (2024-10-18)
2024-10-23 01:28:07 +00:00
Xuan Son Nguyen
c8c07d658a llama : fix empty batch causing llama_batch_allocr to crash (#9966)
* llama : fix empty batch cause llama_batch_allocr to crash

* move batch_allocr inside decode/encode_internal

* fix build

* add GGML_ASSERT

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-10-22 16:59:02 +02:00
19 changed files with 1099 additions and 531 deletions

View File

@@ -53,7 +53,9 @@ if [ ! -z ${GG_BUILD_SYCL} ]; then
 exit 1
 fi
-CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON"
+# Only functionality CI for SYCL now
+GG_BUILD_LOW_PERF=True
+CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON"
 fi
 if [ ! -z ${GG_BUILD_VULKAN} ]; then
@@ -150,7 +152,12 @@ function gg_run_ctest_release {
 if [ -z ${GG_BUILD_LOW_PERF} ]; then
 (time ctest --output-on-failure -L main ) 2>&1 | tee -a $OUT/${ci}-ctest.log
 else
-(time ctest --output-on-failure -L main -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log
+if [ ! -z "$GG_BUILD_SYCL" ]; then
+# TODO(airMeng): fix iq1_xs and iq3_xs quantization in SYCL
+(time ctest --output-on-failure -L main -E "test-quantize-fns|test-opt" ) 2>&1 | tee -a "$OUT/${ci}-ctest.log"
+else
+(time ctest --output-on-failure -L main -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log
+fi
 fi
 set +e
@@ -824,7 +831,10 @@ fi
 ret=0
-test $ret -eq 0 && gg_run ctest_debug
+if [ -z "$GG_BUILD_SYCL" ]; then
+# to save time, remove after more machines available
+test $ret -eq 0 && gg_run ctest_debug
+fi
 test $ret -eq 0 && gg_run ctest_release
 if [ -z ${GG_BUILD_LOW_PERF} ]; then

View File

@@ -2,7 +2,7 @@
 "
 " requires:
 "
-" - neovim
+" - neovim or vim
 " - curl
 " - llama.cpp server instance
 " - FIM-compatible model
@@ -10,7 +10,7 @@
 " sample config:
 "
 " - Tab - accept the current suggestion
-" - Shift+Tab - accept just the first line of the segguestion
+" - Shift+Tab - accept just the first line of the suggestion
 " - Ctrl+F - toggle FIM completion manually
 "
 " make symlink or copy this file to ~/.config/nvim/autoload/llama.vim
@@ -43,8 +43,8 @@
 "
 " colors (adjust to your liking)
-highlight llama_hl_hint guifg=#ff772f
-highlight llama_hl_info guifg=#77ff2f
+highlight llama_hl_hint guifg=#ff772f ctermfg=202
+highlight llama_hl_info guifg=#77ff2f ctermfg=119
 " general parameters:
 "
@@ -81,7 +81,7 @@ let s:default_config = {
 \ 'n_suffix': 64,
 \ 'n_predict': 128,
 \ 't_max_prompt_ms': 500,
-\ 't_max_predict_ms': 1000,
+\ 't_max_predict_ms': 3000,
 \ 'show_info': 2,
 \ 'auto_fim': v:true,
 \ 'max_line_suffix': 8,
@@ -93,6 +93,18 @@ let s:default_config = {
 let g:llama_config = get(g:, 'llama_config', s:default_config)
+function! s:get_indent(str)
+let l:count = 0
+for i in range(len(a:str))
+if a:str[i] == "\t"
+let l:count += &tabstop - 1
+else
+break
+endif
+endfor
+return l:count
+endfunction
 function! s:rand(i0, i1) abort
 return a:i0 + rand() % (a:i1 - a:i0 + 1)
 endfunction
@@ -129,6 +141,21 @@ function! llama#init()
 let s:current_job = v:null
+let s:ghost_text_nvim = exists('*nvim_buf_get_mark')
+let s:ghost_text_vim = has('textprop')
+if s:ghost_text_vim
+let s:hlgroup_hint = 'llama_hl_hint'
+let s:hlgroup_info = 'llama_hl_info'
+if empty(prop_type_get(s:hlgroup_hint))
+call prop_type_add(s:hlgroup_hint, {'highlight': s:hlgroup_hint})
+endif
+if empty(prop_type_get(s:hlgroup_info))
+call prop_type_add(s:hlgroup_info, {'highlight': s:hlgroup_info})
+endif
+endif
 augroup llama
 autocmd!
 autocmd InsertEnter * inoremap <expr> <silent> <C-F> llama#fim_inline(v:false)
@@ -317,13 +344,22 @@ function! s:ring_update()
 \ 't_max_predict_ms': 1
 \ })
-let l:curl_command = printf(
-\ "curl --silent --no-buffer --request POST --url %s --header \"Content-Type: application/json\" --data %s",
-\ g:llama_config.endpoint, shellescape(l:request)
-\ )
+let l:curl_command = [
+\ "curl",
+\ "--silent",
+\ "--no-buffer",
+\ "--request", "POST",
+\ "--url", g:llama_config.endpoint,
+\ "--header", "Content-Type: application/json",
+\ "--data", l:request
+\ ]
 " no callbacks because we don't need to process the response
-call jobstart(l:curl_command, {})
+if s:ghost_text_nvim
+call jobstart(l:curl_command, {})
+elseif s:ghost_text_vim
+call job_start(l:curl_command, {})
+endif
 endfunction
 " necessary for 'inoremap <expr>'
@@ -418,24 +454,37 @@ function! llama#fim(is_auto) abort
 \ 't_max_predict_ms': g:llama_config.t_max_predict_ms
 \ })
-let l:curl_command = printf(
-\ "curl --silent --no-buffer --request POST --url %s --header \"Content-Type: application/json\" --data %s",
-\ g:llama_config.endpoint, shellescape(l:request)
-\ )
+let l:curl_command = [
+\ "curl",
+\ "--silent",
+\ "--no-buffer",
+\ "--request", "POST",
+\ "--url", g:llama_config.endpoint,
+\ "--header", "Content-Type: application/json",
+\ "--data", l:request
+\ ]
 if s:current_job != v:null
-call jobstop(s:current_job)
+if s:ghost_text_nvim
+call jobstop(s:current_job)
+elseif s:ghost_text_vim
+call job_stop(s:current_job)
+endif
 endif
 " send the request asynchronously
-let s:current_job = jobstart(l:curl_command, {
-\ 'on_stdout': function('s:fim_on_stdout'),
-\ 'on_exit': function('s:fim_on_exit'),
-\ 'stdout_buffered': v:true,
-\ 'pos_x': s:pos_x,
-\ 'pos_y': s:pos_y,
-\ 'is_auto': a:is_auto
-\ })
+if s:ghost_text_nvim
+let s:current_job = jobstart(l:curl_command, {
+\ 'on_stdout': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]),
+\ 'on_exit': function('s:fim_on_exit'),
+\ 'stdout_buffered': v:true
+\ })
+elseif s:ghost_text_vim
+let s:current_job = job_start(l:curl_command, {
+\ 'out_cb': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]),
+\ 'exit_cb': function('s:fim_on_exit')
+\ })
+endif
 " TODO: per-file location
 let l:delta_y = abs(s:pos_y - s:pos_y_pick)
@@ -482,9 +531,13 @@ function! llama#fim_cancel()
 " clear the virtual text
 let l:bufnr = bufnr('%')
-let l:id_vt_fim = nvim_create_namespace('vt_fim')
-call nvim_buf_clear_namespace(l:bufnr, l:id_vt_fim, 0, -1)
+if s:ghost_text_nvim
+let l:id_vt_fim = nvim_create_namespace('vt_fim')
+call nvim_buf_clear_namespace(l:bufnr, l:id_vt_fim, 0, -1)
+elseif s:ghost_text_vim
+call prop_remove({'type': s:hlgroup_hint, 'all': v:true})
+call prop_remove({'type': s:hlgroup_info, 'all': v:true})
+endif
 " remove the mappings
 silent! iunmap <buffer> <Tab>
@@ -499,13 +552,18 @@ function! s:on_move()
 endfunction
 " callback that processes the FIM result from the server and displays the suggestion
-function! s:fim_on_stdout(job_id, data, event) dict
-let l:raw = join(a:data, "\n")
+function! s:fim_on_stdout(pos_x, pos_y, is_auto, job_id, data, event = v:null)
+if s:ghost_text_nvim
+let l:raw = join(a:data, "\n")
+elseif s:ghost_text_vim
+let l:raw = a:data
+endif
 if len(l:raw) == 0
 return
 endif
-if self.pos_x != col('.') - 1 || self.pos_y != line('.')
+if a:pos_x != col('.') - 1 || a:pos_y != line('.')
 return
 endif
@@ -514,14 +572,14 @@ function! s:fim_on_stdout(job_id, data, event) dict
 return
 endif
-let s:pos_x = self.pos_x
-let s:pos_y = self.pos_y
+let s:pos_x = a:pos_x
+let s:pos_y = a:pos_y
 let s:can_accept = v:true
 let l:has_info = v:false
 if s:can_accept && v:shell_error
-if !self.is_auto
+if !a:is_auto
 call add(s:content, "<| curl error: is the server on? |>")
 endif
 let s:can_accept = v:false
@@ -642,7 +700,9 @@ function! s:fim_on_stdout(job_id, data, event) dict
 " display virtual text with the suggestion
 let l:bufnr = bufnr('%')
-let l:id_vt_fim = nvim_create_namespace('vt_fim')
+if s:ghost_text_nvim
+let l:id_vt_fim = nvim_create_namespace('vt_fim')
+endif
 " construct the info message
 if g:llama_config.show_info > 0 && l:has_info
@@ -671,15 +731,41 @@ function! s:fim_on_stdout(job_id, data, event) dict
 endif
 " display the suggestion and append the info to the end of the first line
-call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, s:pos_x - 1, {
-\ 'virt_text': [[s:content[0], 'llama_hl_hint'], [l:info, 'llama_hl_info']],
-\ 'virt_text_win_col': virtcol('.') - 1
-\ })
+if s:ghost_text_nvim
+call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, s:pos_x - 1, {
+\ 'virt_text': [[s:content[0], 'llama_hl_hint'], [l:info, 'llama_hl_info']],
+\ 'virt_text_win_col': virtcol('.') - 1
+\ })
-call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, 0, {
-\ 'virt_lines': map(s:content[1:], {idx, val -> [[val, 'llama_hl_hint']]}),
-\ 'virt_text_win_col': virtcol('.')
-\ })
+call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, 0, {
+\ 'virt_lines': map(s:content[1:], {idx, val -> [[val, 'llama_hl_hint']]}),
+\ 'virt_text_win_col': virtcol('.')
+\ })
+elseif s:ghost_text_vim
+let l:new_suffix = s:content[0]
+if !empty(l:new_suffix)
+call prop_add(s:pos_y, s:pos_x + 1, {
+\ 'type': s:hlgroup_hint,
+\ 'text': l:new_suffix
+\ })
+endif
+for line in s:content[1:]
+call prop_add(s:pos_y, 0, {
+\ 'type': s:hlgroup_hint,
+\ 'text': line,
+\ 'text_padding_left': s:get_indent(line),
+\ 'text_align': 'below'
+\ })
+endfor
+if !empty(l:info)
+call prop_add(s:pos_y, 0, {
+\ 'type': s:hlgroup_info,
+\ 'text': l:info,
+\ 'text_padding_left': col('$'),
+\ 'text_wrap': 'truncate'
+\ })
+endif
+endif
 " setup accept shortcuts
 inoremap <buffer> <Tab> <C-O>:call llama#fim_accept(v:false)<CR>
@@ -688,7 +774,7 @@ function! s:fim_on_stdout(job_id, data, event) dict
 let s:hint_shown = v:true
 endfunction
-function! s:fim_on_exit(job_id, exit_code, event) dict
+function! s:fim_on_exit(job_id, exit_code, event = v:null)
 if a:exit_code != 0
 echom "Job failed with exit code: " . a:exit_code
 endif

View File

@@ -319,6 +319,18 @@ node index.js
 - The prompt is a string or an array with the first element given as a string
 - The model's `tokenizer.ggml.add_bos_token` metadata is `true`
+These input shapes and data type are allowed for `prompt`:
+- Single string: `"string"`
+- Single sequence of tokens: `[12, 34, 56]`
+- Mixed tokens and strings: `[12, 34, "string", 56, 78]`
+Multiple prompts are also supported. In this case, the completion result will be an array.
+- Only strings: `["string1", "string2"]`
+- Strings and sequences of tokens: `["string1", [12, 34, 56]]`
+- Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string"]`
 `temperature`: Adjust the randomness of the generated text. Default: `0.8`
 `dynatemp_range`: Dynamic temperature range. The final temperature will be in the range of `[temperature - dynatemp_range; temperature + dynatemp_range]` Default: `0.0`, which is disabled.
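
To illustrate the mixed tokens-and-strings shape together with the BOS condition listed in the hunk above, here is a minimal C++ sketch. It is not the server's implementation: `prompt_piece`, `toy_tokenize`, and `tokenize_mixed` are hypothetical names, the byte-level tokenizer and the BOS id of 1 are placeholders, and JSON parsing is left out entirely.

```cpp
#include <cassert>
#include <string>
#include <variant>
#include <vector>

using token        = int;
using prompt_piece = std::variant<token, std::string>;

// Placeholder tokenizer (assumption): one token per byte, optional BOS id 1.
std::vector<token> toy_tokenize(const std::string & text, bool add_bos) {
    std::vector<token> out;
    if (add_bos) {
        out.push_back(1);
    }
    for (unsigned char c : text) {
        out.push_back(1000 + c);
    }
    return out;
}

// Flattens a mixed prompt into one token sequence. Special tokens are only
// requested when the very first piece is a string.
std::vector<token> tokenize_mixed(const std::vector<prompt_piece> & pieces, bool add_special) {
    std::vector<token> out;
    bool first = true;
    for (const prompt_piece & piece : pieces) {
        if (const std::string * text = std::get_if<std::string>(&piece)) {
            const std::vector<token> toks = toy_tokenize(*text, add_special && first);
            out.insert(out.end(), toks.begin(), toks.end());
        } else {
            out.push_back(std::get<token>(piece));
        }
        first = false;
    }
    return out;
}

// Usage: a leading raw token suppresses BOS, a leading string keeps it.
void check_tokenize_mixed() {
    const std::vector<prompt_piece> mixed      = {12, 34, std::string("ab"), 56};
    const std::vector<prompt_piece> text_first = {std::string("a"), 7};
    assert((tokenize_mixed(mixed, true)      == std::vector<token>{12, 34, 1097, 1098, 56}));
    assert((tokenize_mixed(text_first, true) == std::vector<token>{1, 1097, 7}));
}
```

Raw token ids pass through untouched, so a client that already tokenized part of its prompt can interleave that part with plain text.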

View File

@@ -43,21 +43,6 @@
 #include <unordered_map>
 #include <unordered_set>
-#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 using json = nlohmann::ordered_json;
 enum stop_type {
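
The macros in the hunk above all prefix log lines with `%12.*s` fed by `12, __func__`. As a small sketch of what that conversion does (the helper name `srv_prefix` is hypothetical, and plain `snprintf` stands in for the project's `LOG_*` macros): the minimum field width of 12 left-pads short function names, while the precision of 12, supplied through `*`, truncates long ones, so the name column keeps a fixed width.

```cpp
#include <cassert>
#include <cstdio>
#include <string>

// Builds the "srv" log prefix with the same conversion as the macros above.
std::string srv_prefix(const char * func) {
    char buf[64];
    // "%12.*s": minimum width 12 (right-aligned), maximum 12 characters printed.
    const int n = std::snprintf(buf, sizeof(buf), "srv %12.*s: ", 12, func);
    assert(n > 0 && n < static_cast<int>(sizeof(buf)));
    return std::string(buf);
}
```

With this format every prefix is the same length regardless of the calling function's name, which keeps the `|`-separated fields that follow aligned in the log.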
@@ -68,6 +53,7 @@ enum stop_type {
 // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
 enum slot_state {
 SLOT_STATE_IDLE,
+SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
 SLOT_STATE_PROCESSING_PROMPT,
 SLOT_STATE_DONE_PROMPT,
 SLOT_STATE_GENERATING,
@@ -79,7 +65,7 @@ enum server_state {
 };
 enum server_task_type {
-SERVER_TASK_TYPE_COMPLETION,
+SERVER_TASK_TYPE_INFERENCE,
 SERVER_TASK_TYPE_CANCEL,
 SERVER_TASK_TYPE_NEXT_RESPONSE,
 SERVER_TASK_TYPE_METRICS,
@@ -89,21 +75,22 @@ enum server_task_type {
 SERVER_TASK_TYPE_SET_LORA,
 };
-enum server_task_cmpl_type {
-SERVER_TASK_CMPL_TYPE_NORMAL,
-SERVER_TASK_CMPL_TYPE_EMBEDDING,
-SERVER_TASK_CMPL_TYPE_RERANK,
-SERVER_TASK_CMPL_TYPE_INFILL,
+enum server_task_inf_type {
+SERVER_TASK_INF_TYPE_COMPLETION,
+SERVER_TASK_INF_TYPE_EMBEDDING,
+SERVER_TASK_INF_TYPE_RERANK,
+SERVER_TASK_INF_TYPE_INFILL,
 };
 struct server_task {
 int id = -1; // to be filled by server_queue
 int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
+llama_tokens prompt_tokens;
 server_task_type type;
 json data;
-server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
 // utility function
 static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@@ -161,26 +148,20 @@ struct server_slot {
 int32_t i_batch = -1;
 int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
+// n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
 int32_t n_prompt_tokens = 0;
 int32_t n_prompt_tokens_processed = 0;
-json prompt; // can be either a string, array of strings or array of token ids
-json input_prefix;
-json input_suffix;
-json input_extra;
-// when a task is submitted, we first tokenize the prompt and store it here
-std::vector<llama_token> prompt_tokens;
-std::vector<llama_token> extra_tokens;
+// input prompt tokens
+llama_tokens prompt_tokens;
 size_t last_nl_pos = 0;
 std::string generated_text;
-std::vector<llama_token> cache_tokens;
+llama_tokens cache_tokens;
 std::vector<completion_token_output> generated_token_probs;
-server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
 bool has_next_token = true;
 bool has_new_line = false;
@@ -229,7 +210,7 @@ struct server_slot {
 n_past = 0;
 n_sent_text = 0;
 n_sent_token_probs = 0;
-cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
 generated_token_probs.clear();
 }
@@ -734,42 +715,6 @@ struct server_context {
 metrics.init();
 }
-std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
-// If `add_bos` is true, we only add BOS, when json_prompt is a string,
-// or the first element of the json_prompt array is a string.
-std::vector<llama_token> prompt_tokens;
-if (json_prompt.is_array()) {
-bool first = true;
-for (const auto & p : json_prompt) {
-if (p.is_string()) {
-auto s = p.template get<std::string>();
-std::vector<llama_token> p;
-if (first) {
-p = common_tokenize(ctx, s, add_special, parse_special);
-first = false;
-} else {
-p = common_tokenize(ctx, s, false, parse_special);
-}
-prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
-} else {
-if (first) {
-first = false;
-}
-prompt_tokens.push_back(p.template get<llama_token>());
-}
-}
-} else {
-auto s = json_prompt.template get<std::string>();
-prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
-}
-return prompt_tokens;
-}
 server_slot * get_slot_by_id(int id) {
 for (server_slot & slot : slots) {
 if (slot.id == id) {
@@ -794,22 +739,16 @@ struct server_context {
 continue;
 }
-// skip the slot if it does not contains prompt
-if (!slot.prompt.is_string()) {
+// skip the slot if it does not contains cached tokens
+if (slot.prompt_tokens.empty()) {
 continue;
 }
-// current slot's prompt
-std::string slot_prompt = slot.prompt.get<std::string>();
-// length of the current slot's prompt
-int slot_prompt_len = slot_prompt.size();
 // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-int lcp_len = longest_common_prefix(slot_prompt, prompt);
+int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens);
 // fraction of the common substring length compared to the current slot's prompt length
-similarity = static_cast<float>(lcp_len) / slot_prompt_len;
+similarity = static_cast<float>(lcp_len) / static_cast<int>(slot.prompt_tokens.size());
 // select the current slot if the criteria match
 if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
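
A minimal sketch of the token-based slot matching in the hunk above, under stated assumptions: `llama_tokens` is taken to be a vector of integer token ids, the body of `longest_common_prefix` is a stand-in because the hunk only shows the call, and `prefix_similarity` is a hypothetical helper name that does not appear in the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using llama_token  = int;                       // assumption: plain integer ids
using llama_tokens = std::vector<llama_token>;

// Number of leading tokens that two sequences have in common.
size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

// Fraction of `prompt` that is already covered by `cache`, in [0, 1].
// An empty prompt yields 0 so that the division is never by zero.
float prefix_similarity(const llama_tokens & cache, const llama_tokens & prompt) {
    if (prompt.empty()) {
        return 0.0f;
    }
    const size_t lcp = longest_common_prefix(cache, prompt);
    assert(lcp <= prompt.size());
    return static_cast<float>(lcp) / static_cast<float>(prompt.size());
}
```

Comparing token ids instead of raw strings means that prompts submitted as token arrays can take part in slot reuse as well, which the string-only check on the removed lines could not do.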
@@ -914,57 +853,6 @@ struct server_context {
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
}
// infill
slot.input_prefix = json_value(data, "input_prefix", json());
slot.input_suffix = json_value(data, "input_suffix", json());
slot.input_extra = json_value(data, "input_extra", json());
SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
for (const auto & chunk : slot.input_extra) {
// { "text": string, "filename": string }
if (!chunk.contains("text") || !chunk["text"].is_string()) {
send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
return false;
}
// filename is optional
if (chunk.contains("filename") && !chunk["filename"].is_string()) {
send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
return false;
}
SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
}
// get prompt
{
const auto & prompt = data.find("prompt");
if (prompt == data.end()) {
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
return false;
}
if ((prompt->is_string()) ||
(prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
slot.prompt = *prompt;
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
slot.prompt = prompt->at(0);
} else if (prompt->is_array() && prompt->size() > 1) {
// array of strings
for (const auto & el : *prompt) {
if (!el.is_string()) {
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
slot.prompt = *prompt;
} else {
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
{
slot.sparams.logit_bias.clear();
@@ -1044,8 +932,7 @@ struct server_context {
}
}
slot.state = SLOT_STATE_PROCESSING_PROMPT;
slot.prompt_tokens.clear();
slot.state = SLOT_STATE_STARTED;
SLT_INF(slot, "%s", "processing task\n");
@@ -1297,7 +1184,7 @@ struct server_context {
};
if (slot.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
@@ -1333,7 +1220,7 @@ struct server_context {
{"tokens_predicted", slot.n_decoded},
{"tokens_evaluated", slot.n_prompt_tokens},
{"generation_settings", get_formated_generation(slot)},
{"prompt", slot.prompt},
{"prompt", common_detokenize(ctx, slot.prompt_tokens)},
{"has_new_line", slot.has_new_line},
{"truncated", slot.truncated},
{"stopped_eos", slot.stopped_eos},
@@ -1348,7 +1235,7 @@ struct server_context {
if (slot.sparams.n_probs > 0) {
std::vector<completion_token_output> probs;
if (!slot.params.stream && slot.stopped_word) {
const std::vector<llama_token> stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
probs = std::vector<completion_token_output>(
@@ -1457,19 +1344,17 @@ struct server_context {
// Functions to create new task(s) and receive result(s)
//
std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
// break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
std::vector<server_task> tasks;
auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
server_task task;
task.id = queue_tasks.get_new_id();
task.cmpl_type = cmpl_type;
task.type = SERVER_TASK_TYPE_COMPLETION;
if (replace_prompt) {
task.data = task_data;
task.data["prompt"] = std::move(prompt);
} else {
task.data = std::move(task_data);
}
task.id = queue_tasks.get_new_id();
task.inf_type = inf_type;
task.type = SERVER_TASK_TYPE_INFERENCE;
task.data = task_data;
task.prompt_tokens = std::move(prompt_tokens);
tasks.push_back(std::move(task));
};
@@ -1478,41 +1363,49 @@ struct server_context {
throw std::runtime_error(error_msg);
}
json prompt = data.at("prompt");
// if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create a single task
if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
data["index"] = 0;
create_task(data, false, nullptr);
} else if (prompt.is_array()) {
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
std::vector<json> prompts = prompt;
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
// prompts[0] is the question
// the rest are the answers/documents
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
for (size_t i = 1; i < prompts.size(); i++) {
json qd;
qd.push_back(prompts[0]);
qd.push_back(prompts[i]);
data["index"] = i - 1;
create_task(data, true, qd);
}
} else {
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
for (size_t i = 0; i < prompts.size(); i++) {
const auto & e = prompts[i];
if (e.is_string() || json_is_array_of_numbers(e)) {
// because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
switch (inf_type) {
case SERVER_TASK_INF_TYPE_RERANK:
{
// prompts[0] is the question
// the rest are the answers/documents
GGML_ASSERT(tokenized_prompts.size() > 1);
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
for (size_t i = 1; i < tokenized_prompts.size(); i++) {
data["index"] = i - 1;
auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
create_task(data, tokens);
}
} break;
case SERVER_TASK_INF_TYPE_INFILL:
{
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
data["index"] = i;
create_task(data, true, e);
} else {
throw std::runtime_error(error_msg);
auto tokens = format_infill(
ctx,
data.at("input_prefix"),
data.at("input_suffix"),
data.at("input_extra"),
params.n_batch,
params.n_predict,
slots[0].n_ctx, // TODO: there should be a better way
params.spm_infill,
tokenized_prompts[i]
);
create_task(data, tokens);
}
} break;
default:
{
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
data["index"] = i;
create_task(data, tokenized_prompts[i]);
}
}
}
} else {
// invalid case
throw std::runtime_error(error_msg);
}
return tasks;
@@ -1534,7 +1427,7 @@ struct server_context {
queue_tasks.post(cancel_tasks, true);
}
// receive the results from task(s) created by create_tasks_cmpl
// receive the results from task(s) created by create_tasks_inference
void receive_cmpl_results(
const std::unordered_set<int> & id_tasks,
const std::function<void(std::vector<server_task_result>&)> & result_handler,
@@ -1558,7 +1451,7 @@ struct server_context {
result_handler(results);
}
// receive the results from task(s) created by create_tasks_cmpl, in stream mode
// receive the results from task(s) created by create_tasks_inference, in stream mode
void receive_cmpl_results_stream(
const std::unordered_set<int> & id_tasks, const
std::function<bool(server_task_result&)> & result_handler, const
@@ -1591,7 +1484,7 @@ struct server_context {
void process_single_task(const server_task & task) {
switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION:
case SERVER_TASK_TYPE_INFERENCE:
{
const int id_slot = json_value(task.data, "id_slot", -1);
@@ -1623,9 +1516,10 @@ struct server_context {
slot->reset();
slot->id_task = task.id;
slot->cmpl_type = task.cmpl_type;
slot->index = json_value(task.data, "index", 0);
slot->id_task = task.id;
slot->inf_type = task.inf_type;
slot->index = json_value(task.data, "index", 0);
slot->prompt_tokens = std::move(task.prompt_tokens);
if (!launch_slot_with_task(*slot, task)) {
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
@@ -1658,7 +1552,7 @@ struct server_context {
slot_data["id"] = slot.id;
slot_data["id_task"] = slot.id_task;
slot_data["state"] = slot.state;
slot_data["prompt"] = slot.prompt;
slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
slot_data["next_token"] = {
{"has_next_token", slot.has_next_token},
{"has_new_line", slot.has_new_line},
@@ -1785,9 +1679,6 @@ struct server_context {
}
slot->cache_tokens.resize(token_count);
// TODO: maybe detokenize the slot->cache_tokens instead?
slot->prompt = string_format("[restored %d tokens from file]", (int) token_count);
const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0;
@@ -1954,142 +1845,18 @@ struct server_context {
if (params.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
auto & prompt_tokens = slot.prompt_tokens;
// we haven't tokenized the prompt yet - do it now:
if (prompt_tokens.empty()) {
SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size());
// TODO: maybe move branch to outside of this loop in the future
if (slot.state == SLOT_STATE_STARTED) {
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_generation = 0;
switch (slot.cmpl_type) {
case SERVER_TASK_CMPL_TYPE_NORMAL:
case SERVER_TASK_CMPL_TYPE_EMBEDDING:
{
prompt_tokens = tokenize(slot.prompt, llama_add_bos_token(model), true);
} break;
case SERVER_TASK_CMPL_TYPE_RERANK:
{
// require slot.prompt to be array of 2 strings
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
slot.release();
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
continue;
}
// prompt: [BOS]query[EOS][SEP]doc[EOS]
prompt_tokens.clear();
prompt_tokens.push_back(llama_token_bos(model));
{
const auto part = tokenize(slot.prompt[0], false, false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
prompt_tokens.push_back(llama_token_sep(model));
{
const auto part = tokenize(slot.prompt[1], false, false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
} break;
case SERVER_TASK_CMPL_TYPE_INFILL:
{
// TODO: optimize this block by reducing memory allocations and movement
// use FIM repo-level pattern:
// ref: https://arxiv.org/pdf/2409.12186
//
// [FIM_REP]myproject
// [FIM_SEP]filename0
// extra chunk 0
// [FIM_SEP]filename1
// extra chunk 1
// ...
// [FIM_SEP]filename
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
//
auto tokens_prefix = tokenize(slot.input_prefix, false, false);
auto tokens_suffix = tokenize(slot.input_suffix, false, false);
auto tokens_prompt = tokenize(slot.prompt, false, false);
slot.extra_tokens.clear();
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
static const auto k_fim_repo = tokenize("myproject\n", false, false);
slot.extra_tokens.push_back(llama_token_fim_rep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
}
for (const auto & chunk : slot.input_extra) {
// { "text": string, "filename": string }
const std::string text = chunk.value("text", "");
const std::string filename = chunk.value("filename", "tmp");
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
const auto k_fim_file = tokenize(filename + "\n", false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} else {
// chunk separator in binary form to avoid confusing the AI
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
}
const auto chunk_tokens = tokenize(text, false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
}
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
// TODO: current filename
static const auto k_fim_file = tokenize("filename\n", false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
}
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
// fill the rest of the context with extra chunks
const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
tokens_suffix.resize(n_suffix_take);
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
if (llama_add_bos_token(model)) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}
SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
// put the extra context before the FIM prefix
embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
embd_inp.push_back(llama_token_fim_mid(model));
prompt_tokens = std::move(embd_inp);
} break;
}
slot.n_past = 0;
slot.n_prompt_tokens = prompt_tokens.size();
slot.state = SLOT_STATE_PROCESSING_PROMPT;
SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
// print prompt tokens (for debugging)
if (1) {
@@ -2114,7 +1881,7 @@ struct server_context {
continue;
}
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
// this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_ubatch) {
slot.release();
@@ -2144,7 +1911,7 @@ struct server_context {
const int n_block_size = n_left / 2;
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
std::vector<llama_token> new_tokens(
llama_tokens new_tokens(
prompt_tokens.begin(),
prompt_tokens.begin() + slot.params.n_keep);
@@ -2163,17 +1930,10 @@ struct server_context {
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
common_sampler_reset(slot.smpl);
if (slot.params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
}
// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params.n_cache_reuse > 0) {
size_t head_c = slot.n_past; // cache
@@ -2206,8 +1966,6 @@ struct server_context {
for (size_t i = 0; i < n_match; i++) {
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
slot.n_past++;
}
@@ -2234,7 +1992,7 @@ struct server_context {
}
// non-causal tasks require the entire prompt to fit in the physical batch
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
continue;
@@ -2243,8 +2001,8 @@ struct server_context {
// check that we are in the right batch_type, if not defer the slot
const bool slot_type =
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
if (batch_type == -1) {
batch_type = slot_type;
@@ -2259,8 +2017,6 @@ struct server_context {
// there is no common part left
slot.n_past = 0;
common_sampler_reset(slot.smpl);
}
SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
@@ -2288,6 +2044,13 @@ struct server_context {
GGML_ASSERT(batch.n_tokens > 0);
common_sampler_reset(slot.smpl);
// Process all prompt tokens through sampler system
for (int i = 0; i < slot.n_prompt_tokens; ++i) {
common_sampler_accept(slot.smpl, prompt_tokens[i], false);
}
// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;
@@ -2357,7 +2120,7 @@ struct server_context {
}
if (slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
// prompt evaluated for embedding
send_embedding(slot, batch_view);
slot.release();
@@ -2365,7 +2128,7 @@ struct server_context {
continue; // continue loop of slots
}
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
send_rerank(slot, batch_view);
slot.release();
slot.i_batch = -1;
@@ -2919,13 +2682,13 @@ int main(int argc, char ** argv) {
res_ok(res, {{ "success", true }});
};
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
if (ctx_server.params.embedding || ctx_server.params.reranking) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
@@ -2971,10 +2734,11 @@ int main(int argc, char ** argv) {
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
};
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
// check model compatibility
std::string err;
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
err += "prefix token is missing. ";
@@ -2985,14 +2749,42 @@ int main(int argc, char ** argv) {
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
err += "middle token is missing. ";
}
if (!err.empty()) {
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
// validate input
if (!data.contains("input_prefix")) {
res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
return;
}
if (!data.contains("input_suffix")) {
res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
return;
}
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
return;
}
json input_extra = json_value(data, "input_extra", json::array());
for (const auto & chunk : input_extra) {
// { "text": string, "filename": string }
if (!chunk.contains("text") || !chunk.at("text").is_string()) {
res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
return;
}
// filename is optional
if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
return;
}
}
data["input_extra"] = input_extra; // default to an empty array if it does not exist
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
};
// TODO: maybe merge this function with "handle_completions_generic"
@@ -3004,7 +2796,7 @@ int main(int argc, char ** argv) {
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
@@ -3077,7 +2869,7 @@ int main(int argc, char ** argv) {
const bool add_special = json_value(body, "add_special", false);
const bool with_pieces = json_value(body, "with_pieces", false);
std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true);
llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);
if (with_pieces) {
for (const auto& token : tokens) {
@@ -3114,7 +2906,7 @@ int main(int argc, char ** argv) {
std::string content;
if (body.count("tokens") != 0) {
const std::vector<llama_token> tokens = body.at("tokens");
const llama_tokens tokens = body.at("tokens");
content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
}
@@ -3148,7 +2940,7 @@ int main(int argc, char ** argv) {
json responses = json::array();
bool error = false;
{
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
@@ -3225,7 +3017,7 @@ int main(int argc, char ** argv) {
json responses = json::array();
bool error = false;
{
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
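The infill code in the prompt-processing loop of this file's diff budgets context with a 3:1 prefix-to-suffix ratio and fills what is left of the context with extra chunks. The arithmetic can be isolated as in the sketch below. `fim_budget` and `compute_fim_budget` are hypothetical names, and this excerpt does not show whether `format_infill` keeps exactly the same formula.

```cpp
#include <algorithm>
#include <cassert>

// How many tokens to keep from each part of an infill request (hypothetical struct).
struct fim_budget {
    int n_prefix_take;
    int n_suffix_take;
    int n_extra_take;
};

// Same arithmetic as the loop code: suffix gets a quarter of the batch,
// prefix gets three quarters minus 3 tokens, and extra chunks get whatever
// context remains after one batch and twice n_predict.
static fim_budget compute_fim_budget(int n_prefix, int n_suffix, int n_extra,
                                     int n_batch, int n_ctx, int n_predict) {
    fim_budget b;
    b.n_suffix_take = std::min(n_suffix, n_batch / 4);
    b.n_prefix_take = std::min(n_prefix, 3 * (n_batch / 4) - 3);
    b.n_extra_take  = std::min(std::max(0, n_ctx - n_batch - 2 * n_predict), n_extra);
    return b;
}
```

With the values from the infill test scenario (batch 1024, KV cache 2048, 64 tokens to predict), the caps are 765 prefix tokens, 256 suffix tokens, and 896 extra tokens.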


@@ -0,0 +1,36 @@
@llama.cpp
@infill
Feature: llama.cpp server
# The current model is made by adding FIM tokens to the existing stories260K
# We may want to use a better model in the future, maybe something like SmolLM 360M
Background: Server startup
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K-infill.gguf from HF repo ggml-org/models
And a model file test-model-infill.gguf
And a model alias tinyllama-infill
And 42 as server seed
And 1024 as batch size
And 1024 as ubatch size
And 2048 KV cache size
And 64 max tokens to predict
And 0.0 temperature
Then the server is starting
Then the server is healthy
Scenario: Infill without input_extra
Given a prompt "Complete this"
And an infill input extra none none
And an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
And an infill input suffix "}\n"
And an infill request with no api error
Then 64 tokens are predicted matching One|day|she|saw|big|scary|bird
Scenario: Infill with input_extra
Given a prompt "Complete this"
And an infill input extra "llama.h" "LLAMA_API int32_t llama_n_threads();\n"
And an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
And an infill input suffix "}\n"
And an infill request with no api error
Then 64 tokens are predicted matching cuts|Jimmy|mom|came|into|the|room


@@ -80,6 +80,11 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.lora_file = None
context.disable_ctx_shift = False
# infill
context.infill_input_extra = None
context.infill_input_suffix = ''
context.infill_input_prefix = ''
context.tasks_result = []
context.concurrent_tasks = []
context.prompts = []
@@ -291,6 +296,28 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
@step('an infill request with {api_error} api error')
@async_run_until_complete
async def step_request_completion(context, api_error: Literal['raised'] | str):
if api_error != 'no':
raise ValueError(f'api_error={api_error} is not yet implemented')
payload = {
"prompt": context.prompts[0],
"input_suffix": context.infill_input_suffix,
"input_prefix": context.infill_input_prefix,
"n_predict": context.n_predict,
"seed": context.seed,
"temperature": context.temperature,
}
if context.infill_input_extra is not None:
payload['input_extra'] = context.infill_input_extra
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
async with session.post(f'{context.base_url}/infill',
json=payload) as response:
assert response.status == 200
context.tasks_result = [await response.json()]
@step('{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
context.completion = context.tasks_result.pop()
@@ -539,6 +566,25 @@ def step_a_prompt_prompt(context, prompt):
context.n_prompts = len(context.prompts)
# TODO: allow this to be repeated
@step('an infill input extra {filename} {text}')
def step_infill_input_extra(context, filename, text):
if filename == 'none':
context.infill_input_extra = None
else:
context.infill_input_extra = [{'filename': filename, 'text': text}]
@step('an infill input suffix {text}')
def step_infill_input_suffix(context, text):
context.infill_input_suffix = text
@step('an infill input prefix {text}')
def step_infill_input_prefix(context, text):
context.infill_input_prefix = text
@step('{num_prompts:d} prompts {prompt} with seed {seed:d}')
def step_many_prompts(context, num_prompts, prompt, seed):
if context.seed is None:


@@ -24,6 +24,22 @@
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
using json = nlohmann::ordered_json;
using llama_tokens = std::vector<llama_token>;
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type {
@@ -52,9 +68,235 @@ static T json_value(const json & body, const std::string & key, const T & defaul
}
//
// chat template utils
// tokenizer and input processing utils
//
static bool json_is_array_of_numbers(const json & data) {
if (data.is_array()) {
for (const auto & e : data) {
if (!e.is_number_integer()) {
return false;
}
}
return true;
}
return false;
}
// does the array contain BOTH numbers and strings?
static bool json_is_array_of_mixed_numbers_strings(const json & data) {
bool seen_string = false;
bool seen_number = false;
if (data.is_array()) {
for (const auto & e : data) {
seen_string |= e.is_string();
seen_number |= e.is_number_integer();
if (seen_number && seen_string) {
return true;
}
}
}
return false;
}
/**
* this handles 2 cases:
* - only string, example: "string"
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
*/
static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
// If `add_special` is true, special tokens (e.g. BOS) are only added when json_prompt is a string,
// or when the first element of the json_prompt array is a string.
llama_tokens prompt_tokens;
if (json_prompt.is_array()) {
bool first = true;
for (const auto & p : json_prompt) {
if (p.is_string()) {
auto s = p.template get<std::string>();
llama_tokens p;
if (first) {
p = common_tokenize(ctx, s, add_special, parse_special);
first = false;
} else {
p = common_tokenize(ctx, s, false, parse_special);
}
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
} else {
if (first) {
first = false;
}
prompt_tokens.push_back(p.template get<llama_token>());
}
}
} else {
auto s = json_prompt.template get<std::string>();
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
}
return prompt_tokens;
}
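To make the "special tokens only for a leading string" rule of `tokenize_mixed` concrete, here is a minimal sketch that avoids both the JSON library and the real tokenizer. Everything in it is an assumption made for illustration: `piece_t` stands in for one JSON array element, `fake_tokenize` emits one token per byte, and `FAKE_BOS` is an invented id.

```cpp
#include <cassert>
#include <string>
#include <variant>
#include <vector>

using token_t  = int;
using tokens_t = std::vector<token_t>;               // stand-in for llama_tokens
using piece_t  = std::variant<token_t, std::string>; // raw token id or text

constexpr token_t FAKE_BOS = 1; // hypothetical BOS id, not taken from any model

// Hypothetical tokenizer: one token per byte, optionally preceded by BOS.
static tokens_t fake_tokenize(const std::string & text, bool add_special) {
    tokens_t out;
    if (add_special) {
        out.push_back(FAKE_BOS);
    }
    for (unsigned char c : text) {
        out.push_back(static_cast<token_t>(c));
    }
    return out;
}

// Mirrors the array branch of tokenize_mixed: only a string in the very
// first position may receive special tokens; raw ids are copied through.
static tokens_t tokenize_pieces(const std::vector<piece_t> & pieces, bool add_special) {
    tokens_t out;
    bool first = true;
    for (const auto & piece : pieces) {
        if (const auto * text = std::get_if<std::string>(&piece)) {
            const tokens_t part = fake_tokenize(*text, add_special && first);
            out.insert(out.end(), part.begin(), part.end());
        } else {
            out.push_back(std::get<token_t>(piece));
        }
        first = false;
    }
    return out;
}
```

A prompt that starts with raw token ids therefore never gets an implicit BOS, which lets a client that supplies its own leading tokens stay in full control of the sequence start.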
/**
* break the input "prompt" object into multiple prompts if needed, then tokenize them
* this supports these cases:
* - "prompt": "string"
* - "prompt": [12, 34, 56]
* - "prompt": [12, 34, "string", 56, 78]
* and multiple prompts (multi-tasks):
* - "prompt": ["string1", "string2"]
* - "prompt": ["string1", [12, 34, 56]]
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
*/
static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
std::vector<llama_tokens> result;
if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
// string or mixed
result.push_back(tokenize_mixed(ctx, json_prompt, add_special, parse_special));
} else if (json_is_array_of_numbers(json_prompt)) {
// array of tokens
result.push_back(json_prompt.get<llama_tokens>());
} else if (json_prompt.is_array()) {
// array of prompts
result.reserve(json_prompt.size());
for (const auto & p : json_prompt) {
if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
result.push_back(tokenize_mixed(ctx, p, add_special, parse_special));
} else if (json_is_array_of_numbers(p)) {
// array of tokens
result.push_back(p.get<llama_tokens>());
} else {
throw std::runtime_error("element of \"prompt\" must be a string, a list of tokens, or a list of mixed strings & tokens");
}
}
} else {
throw std::runtime_error("\"prompt\" must be a string, a list of tokens, a list of mixed strings & tokens, or a list of prompts");
}
return result;
}
//
// template utils
//
// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
llama_tokens result;
result.reserve(doc.size() + query.size() + 4);
result.push_back(llama_token_bos(model));
result.insert(result.end(), query.begin(), query.end());
result.push_back(llama_token_eos(model));
result.push_back(llama_token_sep(model));
result.insert(result.end(), doc.begin(), doc.end());
result.push_back(llama_token_eos(model));
return result;
}
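The `[BOS]query[EOS][SEP]doc[EOS]` layout built by `format_rerank` can be checked with plain integers. The sketch below assumes made-up special-token ids; `special_tokens` and `rerank_layout` are hypothetical names, and in the real code the ids come from the model via `llama_token_bos`, `llama_token_eos`, and `llama_token_sep`.

```cpp
#include <cassert>
#include <vector>

using token = int;

// Hypothetical container for the special-token ids of a model.
struct special_tokens {
    token bos;
    token eos;
    token sep;
};

// Same layout as format_rerank: [BOS] query [EOS] [SEP] doc [EOS]
static std::vector<token> rerank_layout(const special_tokens & st,
                                        const std::vector<token> & query,
                                        const std::vector<token> & doc) {
    std::vector<token> result;
    result.reserve(query.size() + doc.size() + 4); // 4 special tokens in total
    result.push_back(st.bos);
    result.insert(result.end(), query.begin(), query.end());
    result.push_back(st.eos);
    result.push_back(st.sep);
    result.insert(result.end(), doc.begin(), doc.end());
    result.push_back(st.eos);
    return result;
}
```

For example, with `bos = 1`, `eos = 2`, `sep = 3`, a query `{10, 11}` and a document `{20}` produce `1, 10, 11, 2, 3, 20, 2`.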
// format infill task
static llama_tokens format_infill(
const llama_context * ctx,
const json & input_prefix,
const json & input_suffix,
const json & input_extra,
const int n_batch,
const int n_predict,
const int n_ctx,
const bool spm_infill,
const llama_tokens & tokens_prompt
) {
// TODO: optimize this block by reducing memory allocations and movement
// use FIM repo-level pattern:
// ref: https://arxiv.org/pdf/2409.12186
//
// [FIM_REP]myproject
// [FIM_SEP]filename0
// extra chunk 0
// [FIM_SEP]filename1
// extra chunk 1
// ...
// [FIM_SEP]filename
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
//
llama_tokens extra_tokens;
extra_tokens.reserve(n_ctx);
auto model = llama_get_model(ctx);
auto tokens_prefix = tokenize_mixed(ctx, input_prefix, false, false);
auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
// TODO: make project name an input
static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false);
extra_tokens.push_back(llama_token_fim_rep(model));
extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
}
for (const auto & chunk : input_extra) {
// { "text": string, "filename": string }
const std::string text = json_value(chunk, "text", std::string());
const std::string filename = json_value(chunk, "filename", std::string("tmp"));
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false);
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} else {
// chunk separator in binary form to avoid confusing the AI
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
static const auto k_chunk_prefix_tokens = common_tokenize(ctx, k_chunk_prefix_str, false, false);
extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
}
const auto chunk_tokens = common_tokenize(ctx, text, false, false);
extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
}
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
// TODO: current filename
static const auto k_fim_file = common_tokenize(ctx, "filename\n", false, false);
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
}
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
// fill the rest of the context with extra chunks
const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size());
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
tokens_suffix.resize(n_suffix_take);
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
if (llama_add_bos_token(model)) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}
SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
// put the extra context before the FIM prefix
embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
embd_inp.push_back(llama_token_fim_mid(model));
return embd_inp;
}
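The token budget used by `format_infill` (prefix and suffix share one batch at roughly 3:1, extra chunks fill what is left of the context) is plain integer arithmetic. The sketch below reproduces only those three `std::min` expressions; `fim_budget` and `compute_fim_budget` are hypothetical names, and the `n_batch >= 4` precondition is an assumption added here so that the prefix budget cannot go negative.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical result type: how many tokens of each part are kept.
struct fim_budget {
    int n_prefix_take;
    int n_suffix_take;
    int n_extra_take;
};

static fim_budget compute_fim_budget(int n_prefix, int n_suffix, int n_extra,
                                     int n_batch, int n_predict, int n_ctx) {
    assert(n_batch >= 4); // assumption of this sketch, keeps 3*(n_batch/4) - 3 >= 0
    fim_budget b;
    // the suffix gets at most a quarter of the batch
    b.n_suffix_take = std::min(n_suffix, n_batch/4);
    // the prefix gets at most three quarters, minus 3 slots
    b.n_prefix_take = std::min(n_prefix, 3*(n_batch/4) - 3);
    // extra chunks use whatever context remains after one batch and 2*n_predict tokens
    b.n_extra_take  = std::min(std::max(0, n_ctx - n_batch - 2*n_predict), n_extra);
    return b;
}
```

With `n_batch = 2048`, `n_predict = 128`, `n_ctx = 8192`, a 5000-token prefix, a 1000-token suffix, and 10000 extra tokens, this keeps 1533 prefix tokens, 512 suffix tokens, and 5888 extra tokens.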
// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
std::vector<common_chat_msg> chat;
@@ -229,18 +471,6 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
return std::string::npos;
}
static bool json_is_array_of_numbers(const json & data) {
if (data.is_array()) {
for (const auto & e : data) {
if (!e.is_number()) {
return false;
}
}
return true;
}
return false;
}
// TODO: reuse llama_detokenize
template <class Iter>
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {

flake.lock generated

@@ -20,11 +20,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1728492678,
"narHash": "sha256-9UTxR8eukdg+XZeHgxW5hQA9fIKHsKCdOIUycTryeVw=",
"lastModified": 1729256560,
"narHash": "sha256-/uilDXvCIEs3C9l73JTACm4quuHUsIHcns1c+cHUJwA=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "5633bcff0c6162b9e4b5f1264264611e950c8ec7",
"rev": "4c2fcb090b1f3e5b47eaa7bd33913b574a11e0a0",
"type": "github"
},
"original": {


@@ -1151,8 +1151,8 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
char * src_ptr = (char *) src->data;
char * dst_ptr = (char *) dst;
const char * src_ptr = (const char *) src->data;
char * dst_ptr = (char *) dst;
const int64_t ne0 = src->ne[0];
const int64_t nb0 = src->nb[0];
@@ -1162,7 +1162,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
const enum ggml_type type = src->type;
const int64_t ts = ggml_type_size(type);
const int64_t bs = ggml_blck_size(type);
int64_t i1_diff = i1_high - i1_low;
const int64_t i1_diff = i1_high - i1_low;
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
if (nb0 == ts && nb1 == ts*ne0/bs) {
@@ -1479,13 +1479,18 @@ static void ggml_cuda_op_mul_mat(
if (src0_is_contiguous) {
dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
} else {
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0));
// If src0 is not contiguous it will be copied to a temporary buffer.
// This buffer needs to be cleared entirely because multiple regions will function as padding.
const size_t nbytes_data = ggml_nbytes(src0);
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
}
// If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared:
// If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
}
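The padding size in the hunk above follows from `MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING` elements converted to bytes by `ggml_row_size`. The sketch below is a stand-alone approximation: `row_size_bytes` and `padding_bytes` are hypothetical helpers, the type size and block size are passed in explicitly instead of being looked up from a `ggml_type`, and the padding value 512 used in the note is an assumption about `MATRIX_ROW_PADDING`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Bytes needed for n_elements of a type that stores blck_size elements in type_size bytes.
// Stand-in for ggml_row_size; n_elements must be a whole number of blocks.
static std::size_t row_size_bytes(std::int64_t type_size, std::int64_t blck_size, std::int64_t n_elements) {
    assert(blck_size > 0 && n_elements % blck_size == 0);
    return (std::size_t) (type_size * n_elements / blck_size);
}

// Bytes of padding appended after the data, as computed in the diff.
// Note: when ne00 is already a multiple of row_padding, this yields one full row_padding of elements.
static std::size_t padding_bytes(std::int64_t type_size, std::int64_t blck_size,
                                 std::int64_t ne00, std::int64_t row_padding) {
    return row_size_bytes(type_size, blck_size, row_padding - ne00 % row_padding);
}
```

Assuming a padding of 512 elements, an F32 row with `ne00 = 1000` needs 24 padding elements, i.e. 96 bytes; a block type with 18 bytes per 32 elements and `ne00 = 4128` needs 480 padding elements, i.e. 270 bytes.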
@@ -3141,7 +3146,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ROPE:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_IM2COL:
return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:


@@ -91,9 +91,9 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t batch = src1->ne[3];
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t batch = src1->ne[is_2D ? 3 : 2];
const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
if(dst->type == GGML_TYPE_F16) {
im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);


@@ -8,8 +8,6 @@ void ggml_cuda_op_mul_mat_q(
const int64_t ne00 = src0->ne[0];
const int64_t nb01 = src0->nb[1];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
GGML_ASSERT(ne10 % QK8_1 == 0);
@@ -17,7 +15,7 @@ void ggml_cuda_op_mul_mat_q(
const int64_t ne0 = dst->ne[0];
const int64_t row_diff = row_high - row_low;
const int64_t stride00 = nb01 / ggml_type_size(src0->type);
const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;


@@ -241,6 +241,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32,
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -272,6 +274,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SIN,
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
GGML_METAL_KERNEL_TYPE_COUNT
};
@@ -685,6 +689,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
@@ -716,6 +722,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
}
[metal_library release];
@@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_IM2COL:
return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_POOL_1D:
case GGML_OP_POOL_2D:
return false;
case GGML_OP_POOL_2D:
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_ARANGE:
@@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node(
} break;
case GGML_OP_IM2COL:
{
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@@ -2574,30 +2584,54 @@ static void ggml_metal_encode_node(
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
id<MTLComputePipelineState> pipeline = nil;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
switch (dst->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
case GGML_TYPE_F32: {
pipeline = (is_gt_mttpt ?
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
:
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
} break;
case GGML_TYPE_F16: {
pipeline = (is_gt_mttpt ?
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
:
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
} break;
default: GGML_ABORT("fatal error");
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
[encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
[encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
if (is_gt_mttpt) {
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
[encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
} else {
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
}
} break;
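The dispatch choice made above for `GGML_OP_IM2COL` can be summarized as: if `N * KH * KW` threads do not fit in one threadgroup, switch to the `_ext` kernel, use a one-dimensional threadgroup, and multiply the grid's first dimension by the number of threadgroups needed to cover `N`. The sketch below restates that decision in portable C++; `im2col_dispatch` and `plan_im2col` are hypothetical names, and `max_threads` stands in for `pipeline.maxTotalThreadsPerThreadgroup`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical description of one dispatch: grid size along x and threadgroup shape.
struct im2col_dispatch {
    bool         use_ext;  // true when the *_EXT_* kernel would be selected
    std::int64_t grid_x;   // threadgroups along x (y and z stay OH and OW)
    std::int64_t tg_x, tg_y, tg_z;
};

static im2col_dispatch plan_im2col(std::int64_t N, std::int64_t IC, std::int64_t KH,
                                   std::int64_t KW, std::int64_t max_threads) {
    assert(N > 0 && max_threads > 0);
    im2col_dispatch d;
    d.use_ext = N * KH * KW > max_threads;
    if (d.use_ext) {
        const std::int64_t n_threads = std::min(max_threads, N);
        const std::int64_t quotient  = (N + n_threads - 1) / n_threads; // ceil(N / n_threads)
        const std::int64_t CHW       = IC * KH * KW;
        d.grid_x = quotient * CHW;
        d.tg_x = n_threads; d.tg_y = 1; d.tg_z = 1;
    } else {
        d.grid_x = IC;
        d.tg_x = N; d.tg_y = KH; d.tg_z = KW;
    }
    return d;
}
```

For instance, with `max_threads = 1024`, `IC = 3`, and a 3x3 kernel, a batch of `N = 2` stays on the original kernel with threadgroups of shape (2, 3, 3), whereas `N = 2000` selects the extended kernel with 1024 threads per threadgroup and `2 * 27 = 54` threadgroups along x.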
case GGML_OP_UPSCALE:
{
@@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_POOL_2D:
{
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
const int32_t * opts = dst->op_params;
enum ggml_op_pool op = opts[0];
id<MTLComputePipelineState> pipeline = nil;
switch (src0t) {
case GGML_TYPE_F32: {
switch(op) {
case GGML_OP_POOL_AVG:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
case GGML_OP_POOL_MAX:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
}
} break;
default: GGML_ASSERT(false && "not implemented");
}
const int32_t k0 = opts[1];
const int32_t k1 = opts[2];
const int32_t s0 = opts[3];
const int32_t s1 = opts[4];
const int32_t p0 = opts[5];
const int32_t p1 = opts[6];
const int64_t IH = src0->ne[1];
const int64_t IW = src0->ne[0];
const int64_t N = dst->ne[3];
const int64_t OC = dst->ne[2];
const int64_t OH = dst->ne[1];
const int64_t OW = dst->ne[0];
const int64_t parallel_elements = N * OC * OH * OW;
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
[encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
} break;
default:
{
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));


@@ -1933,6 +1933,85 @@ kernel void kernel_im2col(
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
typedef void (im2col_ext_t)(
device const float * x,
device char * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
constant int32_t & IH,
constant int32_t & CHW,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int32_t & d0,
constant int32_t & d1,
constant int32_t & N,
constant int32_t & KH,
constant int32_t & KW,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_im2col_ext(
device const float * x,
device char * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
constant int32_t & IH,
constant int32_t & CHW,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int32_t & d0,
constant int32_t & d1,
constant int32_t & N,
constant int32_t & KH,
constant int32_t & KW,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
const int32_t d = tgpig[0] / CHW;
const int32_t chw = tgpig[0] % CHW;
const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
const int32_t HW = tgpig[0] % KHW;
const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
if (tpitg_0 >= N) {
return;
}
const int32_t tpitg_1 = HW / KW;
const int32_t tpitg_2 = HW % KW;
const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
const int32_t offset_dst =
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
(tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
device T * pdst = (device T *) (dst);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
pdst[offset_dst] = 0.0f;
} else {
const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
kernel void kernel_upscale_f32(
device const char * src0,
device char * dst,
@@ -6372,3 +6451,102 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
kernel void kernel_pool_2d_max_f32(
device const float * src0,
device float * dst,
constant int32_t & k0,
constant int32_t & k1,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int64_t & IH,
constant int64_t & IW,
constant int64_t & OH,
constant int64_t & OW,
constant int64_t & parallel_elements,
uint gid[[thread_position_in_grid]]) {
if (gid >= parallel_elements) {
return;
}
const int idx = gid;
const int I_HW = IH * IW;
const int O_HW = OH * OW;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / OW;
const int cur_ow = idx % O_HW % OW;
device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * s1 - p1;
const int bh = MAX(0, start_h);
const int eh = MIN(IH, start_h + k1);
const int start_w = cur_ow * s0 - p0;
const int bw = MAX(0, start_w);
const int ew = MIN(IW, start_w + k0);
float res = -INFINITY;
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
res = MAX(res, i_ptr[i * IW + j]);
}
}
o_ptr[cur_oh * OW + cur_ow] = res;
}
kernel void kernel_pool_2d_avg_f32(
device const float * src0,
device float * dst,
constant int32_t & k0,
constant int32_t & k1,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int64_t & IH,
constant int64_t & IW,
constant int64_t & OH,
constant int64_t & OW,
constant int64_t & parallel_elements,
uint gid[[thread_position_in_grid]]) {
if (gid >= parallel_elements) {
return;
}
const int idx = gid;
const int I_HW = IH * IW;
const int O_HW = OH * OW;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / OW;
const int cur_ow = idx % O_HW % OW;
device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * s1 - p1;
const int bh = MAX(0, start_h);
const int eh = MIN(IH, start_h + k1);
const int start_w = cur_ow * s0 - p0;
const int bw = MAX(0, start_w);
const int ew = MIN(IW, start_w + k0);
// const float scale = 1. / ((eh - bh) * (ew - bw));
const float scale = 1. / (k0 * k1);
float res = 0;
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
float cur = i_ptr[i * IW + j];
res += cur * scale;
}
}
o_ptr[cur_oh * OW + cur_ow] = res;
}
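A CPU reference for a single input plane makes the window clamping in the two pooling kernels easier to follow. This is a sketch under stated assumptions, not the Metal code: `pool_kind`, `pool_params`, and `pool_2d_ref` are hypothetical names, the output size is supplied by the caller, and only one `IH x IW` plane is processed. As in `kernel_pool_2d_avg_f32`, the average is scaled by `1 / (k0 * k1)`, so positions that fall into the padding count as zeros.

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>

enum class pool_kind { maximum, average };

// Kernel size (k), stride (s) and padding (p); index 0 is width, index 1 is height.
struct pool_params {
    int k0, k1, s0, s1, p0, p1;
};

static std::vector<float> pool_2d_ref(const std::vector<float> & src, int IH, int IW,
                                      int OH, int OW, const pool_params & pp, pool_kind kind) {
    assert((int) src.size() == IH * IW);
    std::vector<float> dst(OH * OW);
    const float scale = 1.0f / (pp.k0 * pp.k1);
    for (int oh = 0; oh < OH; ++oh) {
        for (int ow = 0; ow < OW; ++ow) {
            // clamp the window to the valid input region
            const int start_h = oh * pp.s1 - pp.p1;
            const int bh = std::max(0, start_h);
            const int eh = std::min(IH, start_h + pp.k1);
            const int start_w = ow * pp.s0 - pp.p0;
            const int bw = std::max(0, start_w);
            const int ew = std::min(IW, start_w + pp.k0);
            float res = kind == pool_kind::maximum ? -std::numeric_limits<float>::infinity() : 0.0f;
            for (int i = bh; i < eh; ++i) {
                for (int j = bw; j < ew; ++j) {
                    const float cur = src[i * IW + j];
                    res = kind == pool_kind::maximum ? std::max(res, cur) : res + cur * scale;
                }
            }
            dst[oh * OW + ow] = res;
        }
    }
    return dst;
}
```

On the 2x2 input `{1, 2, 3, 4}` with a 2x2 window, stride 2, and no padding, the maximum is 4 and the average is 2.5. With padding 1 and a 2x2 output, each window covers a single input element, so the averages are `0.25, 0.5, 0.75, 1.0` rather than the element values themselves.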


@@ -5173,6 +5173,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
if (op->op == GGML_OP_MUL_MAT) {
a = op->src[0];
b = op->src[1];
if (ggml_is_permuted(a) || ggml_is_permuted(b)) {
// TODO: fix like https://github.com/ggerganov/llama.cpp/pull/10021
return false;
}
} else {
a = op->src[2];
b = op->src[1];


@@ -8,7 +8,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE;
assert(nwarps % WARP_SIZE == 0);
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
for (int col = tid; col < ncols; col += block_size) {
@@ -55,7 +54,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
int end = start + group_size;
const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE;
assert(nwarps % WARP_SIZE == 0);
start += item_ct1.get_local_id(2);
int nreduce = nwarps / WARP_SIZE;
@@ -144,7 +142,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
const int tid = item_ct1.get_local_id(2);
const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE;
assert(nwarps % WARP_SIZE == 0);
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
@@ -202,6 +199,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size);
/*
DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
@@ -244,6 +242,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size);
/*
DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
@@ -290,6 +289,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size);
/*
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed


@@ -3464,7 +3464,7 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
size_t ggml_nbytes(const struct ggml_tensor * tensor) {
size_t nbytes;
size_t blck_size = ggml_blck_size(tensor->type);
const size_t blck_size = ggml_blck_size(tensor->type);
if (blck_size == 1) {
nbytes = ggml_type_size(tensor->type);
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
@@ -3852,10 +3852,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
},
};
for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
g_state.contexts[i].used = false;
}
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);


@@ -1 +1 @@
2327bda7a55ac6b72614ac5ebd5c5a5e02553b9b
6dccc647264f5429df2624f36138f601e7ce23e5


@@ -5177,6 +5177,57 @@ struct llama_model_loader {
}
};
// temporarily allocate memory for the input batch if needed
static const llama_seq_id batch_default_seq_id = 0;
struct llama_batch_allocr {
std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> logits;
struct llama_batch batch;
// optionally fill in the missing fields of the batch returned by llama_batch_get_one
llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
batch = in_batch;
GGML_ASSERT(batch.n_tokens > 0);
if (!batch.pos) {
// determine the last position in KV cache
llama_pos last_pos = -1;
for (const auto & cell : ctx.kv_self.cells) {
if (cell.has_seq_id(batch_default_seq_id)) {
last_pos = std::max(last_pos, cell.pos);
}
}
last_pos++; // next position
pos.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
pos[i] = i+last_pos;
}
batch.pos = pos.data();
}
if (!batch.n_seq_id) {
n_seq_id.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
n_seq_id[i] = seq_id_0.size();
}
batch.n_seq_id = n_seq_id.data();
}
if (!batch.seq_id) {
seq_id.resize(batch.n_tokens + 1);
seq_id[batch.n_tokens] = NULL;
for (int32_t i = 0; i < batch.n_tokens; i++) {
seq_id[i] = seq_id_0.data();
}
batch.seq_id = seq_id.data();
}
if (!batch.logits) {
logits.resize(batch.n_tokens);
logits[logits.size() - 1] = true;
batch.logits = logits.data();
}
}
};
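The default-position logic in `llama_batch_allocr` (continue right after the highest position already stored for sequence 0, and request logits only for the last token) can be illustrated without a `llama_context`. In the sketch below, `kv_cell`, `default_positions`, and `default_logits` are hypothetical stand-ins; the real code inspects `ctx.kv_self.cells` and `cell.has_seq_id(batch_default_seq_id)`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, simplified view of a KV cache cell.
struct kv_cell {
    int  pos;             // position stored in this cell
    bool has_default_seq; // whether the cell belongs to the default sequence (id 0)
};

// Positions assigned when the caller did not provide batch.pos.
static std::vector<int> default_positions(const std::vector<kv_cell> & cells, int n_tokens) {
    assert(n_tokens > 0);
    int last_pos = -1;
    for (const auto & cell : cells) {
        if (cell.has_default_seq) {
            last_pos = std::max(last_pos, cell.pos);
        }
    }
    last_pos++; // next free position
    std::vector<int> pos(n_tokens);
    for (int i = 0; i < n_tokens; i++) {
        pos[i] = i + last_pos;
    }
    return pos;
}

// Output flags assigned when the caller did not provide batch.logits:
// only the last token of the batch produces logits.
static std::vector<std::int8_t> default_logits(int n_tokens) {
    assert(n_tokens > 0);
    std::vector<std::int8_t> logits(n_tokens, 0);
    logits.back() = 1;
    return logits;
}
```

If the cache already holds positions 0 and 1 for the default sequence (cells of other sequences are ignored), a 3-token batch is placed at positions `2, 3, 4`; with an empty cache it starts at 0. The `n_tokens > 0` precondition matches the `GGML_ASSERT(batch.n_tokens > 0)` added in the diff, which guards the write to the last logits entry.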
template<>
bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
uint32_t tmp;
@@ -17095,16 +17146,20 @@ static void llama_graph_compute(
//
static int llama_decode_internal(
llama_context & lctx,
llama_batch batch) {
llama_batch inp_batch) {
lctx.is_encoding = false;
const uint32_t n_tokens_all = batch.n_tokens;
if (n_tokens_all == 0) {
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
// temporarily allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(lctx, inp_batch);
const llama_batch & batch = batch_allocr.batch;
const uint32_t n_tokens_all = batch.n_tokens;
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
@@ -17409,17 +17464,20 @@ static int llama_decode_internal(
//
static int llama_encode_internal(
llama_context & lctx,
llama_batch batch) {
llama_batch inp_batch) {
lctx.is_encoding = true;
const uint32_t n_tokens = batch.n_tokens;
if (n_tokens == 0) {
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
// temporarily allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(lctx, inp_batch);
const llama_batch & batch = batch_allocr.batch;
const uint32_t n_tokens = batch.n_tokens;
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
@@ -21090,61 +21148,10 @@ void llama_batch_free(struct llama_batch batch) {
if (batch.logits) free(batch.logits);
}
// temporary allocate memory for the input batch if needed
static const llama_seq_id batch_default_seq_id = 0;
struct llama_batch_allocr {
std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> logits;
struct llama_batch batch;
// optionally fulfill the batch returned by llama_batch_get_one
llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
batch = in_batch;
if (!batch.pos) {
// determine the last position in KV cache
llama_pos last_pos = -1;
for (const auto & cell : ctx->kv_self.cells) {
if (cell.has_seq_id(batch_default_seq_id)) {
last_pos = std::max(last_pos, cell.pos);
}
}
last_pos++; // next position
pos.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
pos[i] = i+last_pos;
}
batch.pos = pos.data();
}
if (!batch.n_seq_id) {
n_seq_id.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
n_seq_id[i] = seq_id_0.size();
}
batch.n_seq_id = n_seq_id.data();
}
if (!batch.seq_id) {
seq_id.resize(batch.n_tokens + 1);
seq_id[batch.n_tokens] = NULL;
for (int32_t i = 0; i < batch.n_tokens; i++) {
seq_id[i] = seq_id_0.data();
}
batch.seq_id = seq_id.data();
}
if (!batch.logits) {
logits.resize(batch.n_tokens);
logits[logits.size() - 1] = true;
batch.logits = logits.data();
}
}
};
int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch) {
llama_batch_allocr batch_allocr(ctx, batch);
const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
const int ret = llama_encode_internal(*ctx, batch);
if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
}
@@ -21155,8 +21162,7 @@ int32_t llama_encode(
int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch) {
llama_batch_allocr batch_allocr(ctx, batch);
const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
const int ret = llama_decode_internal(*ctx, batch);
if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}
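
Note on the removed `llama_batch_allocr`: when a batch arrives without positions, the helper continues numbering right after the highest position already stored for the default sequence, and when it arrives without logits flags, only the last token is marked for output. The following is a minimal standalone sketch of those two defaults, not the library code. `kv_cell_sketch`, `default_positions` and `default_logits` are hypothetical names introduced for illustration, and the sketch guards the empty-batch case, which the removed code does not do.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a KV cache cell; the real cell type has more fields.
struct kv_cell_sketch {
    int32_t pos;
    bool    has_default_seq; // stands in for cell.has_seq_id(batch_default_seq_id)
};

// Positions for a batch that carries no explicit positions:
// continue right after the highest position stored for the default sequence.
std::vector<int32_t> default_positions(const std::vector<kv_cell_sketch> & cells, int32_t n_tokens) {
    int32_t last_pos = -1;
    for (const auto & cell : cells) {
        if (cell.has_default_seq) {
            last_pos = std::max(last_pos, cell.pos);
        }
    }
    const int32_t first = last_pos + 1; // next free position
    std::vector<int32_t> pos(n_tokens > 0 ? n_tokens : 0);
    for (int32_t i = 0; i < (int32_t) pos.size(); i++) {
        pos[i] = first + i;
    }
    return pos;
}

// Default logits flags: only the last token requests output.
// Unlike the removed code, an empty batch is handled without indexing.
std::vector<int8_t> default_logits(int32_t n_tokens) {
    std::vector<int8_t> logits(n_tokens > 0 ? n_tokens : 0, 0);
    if (!logits.empty()) {
        logits.back() = 1;
    }
    return logits;
}
```

With cells at positions 0, 1 and 2 for the default sequence, a three-token batch would be assigned positions 3, 4 and 5; with an empty cache it would start at 0.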


@@ -1650,11 +1650,12 @@ struct test_mul_mat : public test_case {
const int64_t m;
const int64_t n;
const int64_t k;
const std::array<int64_t, 2> bs; // dims 3 and 4
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
const std::array<int64_t, 2> bs; // dims 3 and 4
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
const std::array<int64_t, 4> per; // permutation of dimensions
std::string vars() override {
return VARS_TO_STR7(type_a, type_b, m, n, k, bs, nr);
return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per);
}
double max_nmse_err() override {
@@ -1669,17 +1670,44 @@ struct test_mul_mat : public test_case {
test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int64_t m = 32, int64_t n = 32, int64_t k = 32,
std::array<int64_t, 2> bs = {10, 10},
std::array<int64_t, 2> nr = {2, 2})
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {}
std::array<int64_t, 2> nr = {2, 2},
std::array<int64_t, 4> per = {0, 1, 2, 3})
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0] , bs[1]);
ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
ggml_set_param(ctx, a);
ggml_set_param(ctx, b);
ggml_set_name(a, "a");
ggml_set_name(b, "b");
ggml_tensor * a;
ggml_tensor * b;
const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
if (npermuted > 0) {
GGML_ASSERT(npermuted == 2);
GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);
// Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k.
const int64_t ne_a[4] = {k, m, bs[0], bs[1]};
const int64_t ne_b[4] = {k, n, bs[0]*nr[0], bs[1]*nr[1]};
a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
ggml_set_param(ctx, a);
ggml_set_param(ctx, b);
ggml_set_name(a, "a");
ggml_set_name(b, "b");
a = ggml_permute(ctx, a, per[0], per[1], per[2], per[3]);
b = ggml_permute(ctx, b, per[0], per[1], per[2], per[3]);
ggml_set_name(a, "a_permuted");
ggml_set_name(b, "b_permuted");
} else {
a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
ggml_set_param(ctx, a);
ggml_set_param(ctx, b);
ggml_set_name(a, "a");
ggml_set_name(b, "b");
}
ggml_tensor * out = ggml_mul_mat(ctx, a, b);
ggml_set_name(out, "out");
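
Note on the permuted branch above: the tensors are allocated with shuffled extents and then permuted so that the operands seen by `ggml_mul_mat` have the original `{k, m, bs[0], bs[1]}` shape again, only with non-contiguous strides. The sketch below checks that shape round trip on plain arrays. It assumes that `ggml_permute(ctx, t, a0, a1, a2, a3)` moves source dimension `i` to axis `a_i`; that is an assumption about the library, and `allocated_shape`, `permuted_shape` and `count_permuted` are hypothetical helpers, not ggml functions.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using shape4 = std::array<int64_t, 4>;

// Shape the test allocates: dimension i takes the extent ne[per[i]].
shape4 allocated_shape(const shape4 & ne, const shape4 & per) {
    return { ne[per[0]], ne[per[1]], ne[per[2]], ne[per[3]] };
}

// Assumed model of the permute step: source dimension i lands on axis per[i].
shape4 permuted_shape(const shape4 & src, const shape4 & per) {
    shape4 out{};
    for (int i = 0; i < 4; i++) {
        assert(per[i] >= 0 && per[i] < 4);
        out[per[i]] = src[i];
    }
    return out;
}

// Number of axes that are not in their identity position (the test requires 0 or 2).
int count_permuted(const shape4 & per) {
    int n = 0;
    for (int i = 0; i < 4; i++) {
        n += (per[i] != i);
    }
    return n;
}
```

Under that assumption, `permuted_shape(allocated_shape(ne, per), per)` returns `ne` for any valid permutation, which is why the test can keep describing the operands by `m`, `n` and `k`.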
@@ -3308,13 +3336,49 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
// test cases for 1D im2col
// im2col 1D
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
for (int s0 : {1, 3}) {
for (int p0 : {0, 3}) {
for (int d0 : {1, 3}) {
test_cases.emplace_back(new test_im2col(
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1},
s0, 0, p0, 0, d0, 0, false));
}
}
}
// im2col 2D
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
for (int s0 : {1, 3}) {
for (int s1 : {1, 3}) {
for (int p0 : {0, 3}) {
for (int p1 : {0, 3}) {
for (int d0 : {1, 3}) {
for (int d1 : {1, 3}) {
test_cases.emplace_back(new test_im2col(
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2},
s0, s1, p0, p1, d0, d1, true));
}
}
}
}
}
}
// extra tests for im2col 2D
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2048}, {3, 3, 1, 2048}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
// sycl backend will limit task global_range < MAX_INT
// test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
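
Note on the new im2col loops: each of stride, padding and dilation takes two values, so the 1D loop adds 2^3 = 8 cases and the 2D loop adds 2^6 = 64. To see which output extents those combinations produce, the sketch below applies the standard convolution output-size formula to the width of 20 and kernel width of 3 used in the loops. Whether ggml computes the extent with exactly this expression is an assumption, and `conv_out_size` is a hypothetical helper name.

```cpp
#include <cassert>
#include <cstdint>

// Standard convolution output extent along one axis:
//   in     input extent
//   kernel kernel extent
//   s      stride, p padding (applied on both sides), d dilation
// Integer division truncates, which matches floor for the non-negative values used here.
int64_t conv_out_size(int64_t in, int64_t kernel, int s, int p, int d) {
    return (in + 2 * p - d * (kernel - 1) - 1) / s + 1;
}
```

For example, stride 1, no padding and dilation 1 gives 18 output columns, while stride 3, padding 3 and dilation 3 gives 7.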
@@ -3442,13 +3506,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
#if 1
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
// test cases without permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1}));
@@ -3457,6 +3522,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
// test cases with permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
}
}
for (ggml_type type_a : other_types) {