Compare commits


47 Commits

Author SHA1 Message Date
Douglas Hanley
b54afce9f4 mostly style fixes; fix KQ_mask comment 2024-03-09 13:03:46 -06:00
DAN™
03acc82a85 Clean-up GritLM sample code. 2024-03-09 07:44:25 -05:00
Douglas Hanley
bd3d9fbfed allow to toggle embedding mode 2024-03-07 11:55:27 -06:00
Douglas Hanley
f618e5060a add to gitignore 2024-03-07 01:38:30 -06:00
Douglas Hanley
1ab6aeeeee gritlm embeddings are back babeee 2024-03-07 01:37:08 -06:00
Douglas Hanley
97936078b7 rebase to new embed 2024-03-05 23:23:17 -06:00
Douglas Hanley
805ae529c4 comment out debug printing 2024-03-05 13:44:42 -06:00
Douglas Hanley
a71842d7ef tabs to spaces 2024-03-05 13:44:42 -06:00
Douglas Hanley
e79195fc53 gritlm results match 2024-03-05 13:44:41 -06:00
Douglas Hanley
4be8fb18ed add gritlm example 2024-03-05 13:42:51 -06:00
Jared Van Bortel
bd836944f8 quants : use MM256_SET_M128I consistently to fix gcc 7 build (#5889) 2024-03-05 11:56:37 -05:00
ExtReMLapin
3de31677d3 grammars : blacklists character control set (#5888)
* Prevent control characters from being served in json string

* Prevent control characters from being served in json string (array)
2024-03-05 18:33:08 +02:00
Georgi Gerganov
82cb31eb93 Revert "grammars : don't allow to output unescaped new line in string (#5885)"
This reverts commit b1a4e994fd.
2024-03-05 15:56:24 +02:00
ExtReMLapin
b1a4e994fd grammars : don't allow to output unescaped new line in string (#5885)
* Don't allow grammar json array to output unescaped new line in string

* Don't allow new line in json object string
2024-03-05 15:44:29 +02:00
0cc4m
61d1c88e15 Vulkan Improvements (#5835)
* Improve dequant shaders, add fast q4_0 dequant

* Optimize dmmv non-kquants for GCN

Remove unnecessary SPIR-V shader duplication

* Fix q4_0 dequant dispatch sizes

Fix backend free bug

* Optimize dequant shaders for q4_1, q5_0, q5_1 and q8_0

* Add unary and binary op shader templates

* Fix Vulkan check results

* Enable non-contiguous support for simple ops

* Add argsort

Basic q4_0 mmq shader and unit test

* Speed up q4_0 dequant code, enable mmq for q4_0

* Rework matmul pipeline selection

* Add soft_max alibi support

* Add q4_1, q5_0, q5_1 and q8_0 dequant mat mat mul shaders

* Add environment variable GGML_VK_FORCE_MAX_ALLOCATION_SIZE to limit max buffer size

Rename GGML_VULKAN_DISABLE_F16 to GGML_VK_DISABLE_F16 for consistency
2024-03-05 13:33:42 +01:00
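The GGML_VK_FORCE_MAX_ALLOCATION_SIZE variable added in the Vulkan commit above is an environment override for the maximum buffer size. Below is a minimal sketch of reading such an override at startup using only standard C++ calls; the helper name, the default limit, and the way the backend applies the value are illustrative assumptions, not the actual Vulkan backend code.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Sketch: read an upper bound on buffer allocations from the environment,
// mirroring the GGML_VK_FORCE_MAX_ALLOCATION_SIZE variable described above.
// The fallback value and clamping logic are illustrative only.
static uint64_t vk_max_allocation_size(uint64_t device_limit) {
    if (const char * env = std::getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE")) {
        const uint64_t forced = std::strtoull(env, nullptr, 10);
        if (forced > 0 && forced < device_limit) {
            return forced;
        }
    }
    return device_limit;
}

int main() {
    // e.g. run with: GGML_VK_FORCE_MAX_ALLOCATION_SIZE=1073741824 ./demo
    std::printf("max allocation: %llu bytes\n",
                (unsigned long long) vk_max_allocation_size(4ull << 30));
    return 0;
}
```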
Neo Zhang Jianyu
21b0867433 [SYCL] fix mul_mat fault in CI/unit-test (#5862)
* fix mul_mat fault in cpy_f32_f16

* rm unused function

* add wait() for memcpy

* restore ci/run.sh, rename struct defination, fix bug in ggml_sycl_op_mul_mat_sycl

* fix format issue

* llama : fix segfault from unknown model arch name (#5820)

* llama : fix segfault from unknown model arch name

* llama : make all LLM maps const

This also requires using `std::map::at` instead of its `operator[]`
which does not exist for const maps.

* llama : name LLM_ARCH_UNKNOWN to "(unknown)"

This avoids errors from `std::map::at` when
getting the general name of the model architecture.
Using "(unknown)" instead of an empty string as per suggestion
https://github.com/ggerganov/llama.cpp/pull/5820#issuecomment-1973735284

* llama : remove redundant inner const for LLM_TENSOR_NAMES

The extra const won't do anything here as const maps
return const references to values.

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* llama : remove redundant nullptr check in llm_arch_from_string

Since LLM_ARCH_NAMES is a const map, no spurious elements
with a NULL name are inserted anymore, so this check is dead code.

---------

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* llama : refactor internal quantization functions (#5830)

* scripts : add pod-llama.sh

* ggml : IQ3_S improvements (#5829)

* iq3_s: somewhat faster AVX2 dot product

On Ryzen a 7950X TG-128 increases to 16 t/s from 15.5 t/s using
16 threads. For 8 threads it is 13.85 t/s vs 11.75 t/s.
PP-512 increases to 28.5 t/s from 23.8 t/s.

* iq3_s: somewhat faster ARM_NEON dot product

Still dog slow - 10.7 t/s up from 9.9 t/s.

* iq3_s: another small ARM_NEON improvement

10.7 -> 11.0 t/s. Using vmulq_s8 is faster than the xor - sub trick
that works best on AVX2.

* iq3_s: minor improvement on Metal

49.4 t/s -> 50.3 t/s

* iq3_s: PPL improvement

E.g., for a context of 4096 LLaMA-v2-7B goes to 5.1340 from 5.1653.

* iq3_s: use new grid everywhere

* Fix ARM_NEON

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

* convert-hf : make model class definitions self-contained (#5825)

* convert : automatically fall back to HfVocab if tokenizer.model doesn't exist (#5821)

* ggml : fix IQ3_S AVX implementation (#5834)

ggml-ci

* llama : add abort_callback to interrupt computation (#5409)

* using abort_callback from ggml to stop llama computation

* format fix

* a brief explaining comment

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* server: tests: passkey challenge /  self-extend with context shift demo (#5832)

* server: tests: add models endpoint scenario

* server: /v1/models add some metadata

* server: tests: add debug field in context before scenario

* server: tests: download model from HF, add batch size

* server: tests: add passkey test

* server: tests: add group attention params

* server: do not truncate prompt tokens if self-extend through group attention is enabled

* server: logs: do not truncate log values

* server: tests - passkey - first good working value of nga

* server: tests: fix server timeout

* server: tests: fix passkey, add doc, fix regex content matching, fix timeout

* server: tests: fix regex content matching

* server: tests: schedule slow tests on master

* server: metrics: fix when no prompt processed

* server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1

* server: tests: increase timeout for completion

* server: tests: keep only the PHI-2 test

* server: tests: passkey add a negative test

* flake.lock: Update (#5842)

Flake lock file updates:

• Updated input 'flake-parts':
    'github:hercules-ci/flake-parts/b253292d9c0a5ead9bc98c4e9a26c6312e27d69f' (2024-02-01)
  → 'github:hercules-ci/flake-parts/f7b3c975cf067e56e7cda6cb098ebe3fb4d74ca2' (2024-03-01)
• Updated input 'flake-parts/nixpkgs-lib':
    'github:NixOS/nixpkgs/97b17f32362e475016f942bbdfda4a4a72a8a652?dir=lib' (2024-01-29)
  → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8?dir=lib' (2024-02-29)
• Updated input 'nixpkgs':
    'github:NixOS/nixpkgs/cbc4211f0afffe6dfd2478a62615dd5175a13f9a' (2024-02-23)
  → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8' (2024-02-29)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* server : init http requests thread pool with --parallel if set (#5836)

* ci : schedule slow server tests only on Release or on demand (#5839)

* llama : fix llama_copy_state_data with fragmented KV cache (#5840)

The row size of the saved states was based on kv_self.head while
it should be based on llama_kv_cache_cell_max.

Existing session files should still work.

* llama : fix llama_kv_cache_cell_max inability to return 1

I've also changed its return type to uint32_t,
because this function is always used to set the value of uint32_t variables,
and because the index already has this type.

* llama : fix state size calculation

Some bytes in the state were unaccounted for in llama_get_state_size.
Since the logits reserve so much space, it did not cause problems.

* gguf-dump : support i-quants (#5841)

Co-authored-by: Black_Fox <radekliska@gmail.com>

* llama : allow for user specified embedding pooling type (#5849)

* allow for user specified pooling type

* llama : use enum types over int

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* readme : add API changes section

* cuda : fix data race in soft max (#5853)

* main : support special tokens as reverse/anti prompt (#5847)

* Support special tokens as reverse/anti prompt.

* Tokenize antiprompts only once.

* main : minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* common : use LLAMA_DEFAULT_SEED (#5855)

* add some new ops, fix some operators and add batch operations to certain operators. (ggml/747)

* cuda: fix group_norm

* cuda: add batch inference support for ggml_pad/ggml_upscale

* add ggml_arrange

* add ggml_timestep_embedding

* update ggml_arange/ggml_timestep_embedding tests

* cuda: fix im2col

* add ggml_arange/ggml_timestep_embbeding support for metal backend

* fix some bugs

* fix some bugs

* Update ggml.h

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-cuda.cu

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* modify according to the review comments

* ggml : fix compile warnings + code style

* ggml : normalize compute_forward calls + fix seg fault in debug

* minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>

* sync : ggml

* add alias for chat template (#5858)

* speculative : implement stochastic speculative sampling (#5625)

* (WIP) Implement stochastic speculative decoding

* sample from residual distribution on draft accept failure

* fix #5657: force greedy sampling with probs when temp is 0

* remove p_accept parameter

* fix style

* remove unused variables

* add srand() in speculative.cpp

* replace use of rand() with mt19937 sampling

* fixes based on review (@JohannesGaessler)

* fix r random generation

* randomly select next sequence to verify + fix bug in memory freeing

* fix bug in active_seqs sync

* fix uniform int distribution initialization

* remove warnings from comparison between int and size_t

* check grammar in `llama_sample_probability_distribution_impl`

* remove malloc code by utilizing vectors

* add PR link to README

* cmake : handle cases where git index is not found in .git (#5844)

* Update CMakeLists.txt

* Update CMakeLists.txt

* ggml : introduce ggml_status (ggml/750)

* using enum as an exit code instead of macros

* update return type from enum to unsigned int

* indentation fix

* compound update
ggml_compute_exit_code -> ggml_status
changed ggml_status from a bit-field type to simple codes
ggml_status to string cast

* ggml_status to string cast

* GGML_CALL was removed

Co-authored-by: slaren <slarengh@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* sync : ggml

ggml-ci

* ggml : fix unknown status (#0)

* flake : fix

* llama : fix embeddings (#5796)

* llama : fix embeddings

ggml-ci

* llama : do not use KV cache for non-causal models

ggml-ci

* embeddings : fix llama_batch_init arg

* llama : add pooling switch

* llama : distinguish token vs sequence embeddings

ggml-ci

* llama : assert pooling tensor

* llama : simplify causal mask condition

ggml-ci

* llama : assert input batch with pooling enabled

* readme : update API changes list

* nix: static build (#5814)

* fix speculative decoding build on windows (#5874)

* rebase and rm tailing space

---------

Co-authored-by: LiangtaoJin <liang-tao.jin@intel.com>
Co-authored-by: compilade <113953597+compilade@users.noreply.github.com>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Kawrakow <48489457+ikawrakow@users.noreply.github.com>
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Michael Podvitskiy <podvitskiymichael@gmail.com>
Co-authored-by: Pierrick Hymbert <pierrick.hymbert@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Nindaleth <Nindaleth@users.noreply.github.com>
Co-authored-by: Black_Fox <radekliska@gmail.com>
Co-authored-by: Douglas Hanley <thesecretaryofwar@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: DAN™ <dranger003@gmail.com>
Co-authored-by: leejet <leejet714@gmail.com>
Co-authored-by: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com>
Co-authored-by: Dane Madsen <dane_madsen@hotmail.com>
Co-authored-by: hutli <6594598+hutli@users.noreply.github.com>
Co-authored-by: Jeffrey Quesnelle <emozilla@nousresearch.com>
2024-03-05 13:38:35 +05:30
Minsoo Cheong
6a87ac3a52 fix editorconfig check break (#5879) 2024-03-05 11:42:23 +05:30
Jeffrey Quesnelle
29eee40474 fix speculative decoding build on windows (#5874) 2024-03-04 22:23:06 -05:00
hutli
1d41d6f7c2 nix: static build (#5814) 2024-03-04 17:33:08 -08:00
Georgi Gerganov
29ae62d2ae llama : fix embeddings (#5796)
* llama : fix embeddings

ggml-ci

* llama : do not use KV cache for non-causal models

ggml-ci

* embeddings : fix llama_batch_init arg

* llama : add pooling switch

* llama : distinguish token vs sequence embeddings

ggml-ci

* llama : assert pooling tensor

* llama : simplify causal mask condition

ggml-ci

* llama : assert input batch with pooling enabled

* readme : update API changes list
2024-03-04 22:31:20 +02:00
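With this fix, sequence-level (pooled) embeddings and per-token embeddings come from different calls; the updated examples/embedding diff further down uses exactly this pattern. A minimal sketch of the retrieval side, assuming the context was created with embeddings enabled and that `ctx`, `batch`, and the decoded index `i` are set up by the caller:

```cpp
#include "llama.h"
#include <cstdio>

// Sketch: fetch one embedding after llama_decode(), following the pattern in
// examples/embedding below: prefer the pooled sequence embedding, fall back
// to the per-token embedding when no pooling is configured.
static const float * get_embedding(llama_context * ctx, const llama_batch & batch, int i) {
    const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
    if (embd == NULL) {
        // pooling is LLAMA_POOLING_TYPE_NONE: use the embedding of output row i
        embd = llama_get_embeddings_ith(ctx, i);
    }
    if (embd == NULL) {
        std::fprintf(stderr, "no embeddings for token %d\n", i);
    }
    return embd; // when non-NULL, points at n_embd floats (see llama_n_embd())
}
```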
Georgi Gerganov
e0843afe1b flake : fix 2024-03-04 21:50:50 +02:00
Georgi Gerganov
a1c6d96ed8 ggml : fix unknown status (#0) 2024-03-04 20:54:23 +02:00
Georgi Gerganov
efd8533ef8 sync : ggml
ggml-ci
2024-03-04 20:54:23 +02:00
Michael Podvitskiy
9fa2627347 ggml : introduce ggml_status (ggml/750)
* using enum as an exit code instead of macros

* update return type from enum to unsigned int

* indentation fix

* compound update
ggml_compute_exit_code -> ggml_status
changed ggml_status from a bit-field type to simple codes
ggml_status to string cast

* ggml_status to string cast

* GGML_CALL was removed

Co-authored-by: slaren <slarengh@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-03-04 20:54:23 +02:00
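A minimal sketch of acting on the new status code, assuming ggml_backend_graph_compute() is among the calls whose return type became enum ggml_status and that ggml_status_to_string() is the "status to string cast" mentioned above:

```cpp
#include "ggml.h"
#include "ggml-backend.h"
#include <cstdio>

// Sketch: treat the compute result as a status code rather than a boolean.
// Assumes `backend` and `graph` were created elsewhere.
static bool compute(ggml_backend_t backend, struct ggml_cgraph * graph) {
    const enum ggml_status status = ggml_backend_graph_compute(backend, graph);
    if (status != GGML_STATUS_SUCCESS) {
        std::fprintf(stderr, "graph compute failed: %s\n", ggml_status_to_string(status));
        return false;
    }
    return true;
}
```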
Dane Madsen
fe52be11e3 cmake : handle cases where git index is not found in .git (#5844)
* Update CMakeLists.txt

* Update CMakeLists.txt
2024-03-04 20:26:55 +02:00
Minsoo Cheong
6d341ab6c5 speculative : implement stochastic speculative sampling (#5625)
* (WIP) Implement stochastic speculative decoding

* sample from residual distribution on draft accept failure

* fix #5657: force greedy sampling with probs when temp is 0

* remove p_accept parameter

* fix style

* remove unused variables

* add srand() in speculative.cpp

* replace use of rand() with mt19937 sampling

* fixes based on review (@JohannesGaessler)

* fix r random generation

* randomly select next sequence to verify + fix bug in memory freeing

* fix bug in active_seqs sync

* fix uniform int distribution initialization

* remove warnings from comparison between int and size_t

* check grammar in `llama_sample_probability_distribution_impl`

* remove malloc code by utilizing vectors

* add PR link to README
2024-03-04 20:24:00 +02:00
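The acceptance rule behind this PR keeps a drafted token with probability min(1, p_target/p_draft) and, on rejection, resamples from the residual distribution max(0, p_target - p_draft). A minimal self-contained sketch of that rule over plain probability vectors (not the llama.cpp sampling context), using the mt19937 generator the PR switches to:

```cpp
#include <algorithm>
#include <random>
#include <vector>

// Sketch of the stochastic speculative-sampling acceptance rule.
// p_target and p_draft are full-vocabulary probability distributions.
static int accept_or_resample(const std::vector<double> & p_target,
                              const std::vector<double> & p_draft,
                              int drafted_token,
                              std::mt19937 & rng) {
    std::uniform_real_distribution<double> unif(0.0, 1.0);

    // accept the drafted token with probability min(1, p_target / p_draft)
    if (p_draft[drafted_token] > 0.0 &&
        unif(rng) < std::min(1.0, p_target[drafted_token] / p_draft[drafted_token])) {
        return drafted_token;
    }

    // rejected: sample a replacement from the residual max(0, p_target - p_draft)
    std::vector<double> residual(p_target.size());
    for (size_t i = 0; i < residual.size(); ++i) {
        residual[i] = std::max(0.0, p_target[i] - p_draft[i]);
    }
    std::discrete_distribution<int> dist(residual.begin(), residual.end());
    return dist(rng);
}
```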
Xuan Son Nguyen
4ffcdce2ff add alias for chat template (#5858) 2024-03-04 12:22:08 +01:00
Georgi Gerganov
a0fc62661f sync : ggml 2024-03-04 10:40:04 +02:00
leejet
7d43c585dc add some new ops, fix some operators and add batch operations to certain operators. (ggml/747)
* cuda: fix group_norm

* cuda: add batch inference support for ggml_pad/ggml_upscale

* add ggml_arrange

* add ggml_timestep_embedding

* update ggml_arange/ggml_timestep_embedding tests

* cuda: fix im2col

* add ggml_arange/ggml_timestep_embbeding support for metal backend

* fix some bugs

* fix some bugs

* Update ggml.h

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-cuda.cu

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* modify according to the review comments

* ggml : fix compile warnings + code style

* ggml : normalize compute_forward calls + fix seg fault in debug

* minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
2024-03-04 10:39:10 +02:00
DAN™
82f3e668ad common : use LLAMA_DEFAULT_SEED (#5855) 2024-03-04 10:08:19 +02:00
DAN™
5a51cc1bb4 main : support special tokens as reverse/anti prompt (#5847)
* Support special tokens as reverse/anti prompt.

* Tokenize antiprompts only once.

* main : minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-03-04 09:57:20 +02:00
slaren
67be2ce101 cuda : fix data race in soft max (#5853) 2024-03-03 14:26:18 +01:00
Georgi Gerganov
231ae28f07 readme : add API changes section 2024-03-03 12:44:03 +02:00
Douglas Hanley
475df1d6cf llama : allow for user specified embedding pooling type (#5849)
* allow for user specified pooling type

* llama : use enum types over int

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-03-03 12:40:27 +02:00
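This surfaces the pooling type on llama_context_params (and as a --pooling {none,mean,cls} flag in common, visible in the common.cpp diff further down). A minimal sketch of requesting mean pooling at context creation, assuming the model is already loaded:

```cpp
#include "llama.h"

// Sketch: request mean pooling for sequence embeddings at context creation,
// instead of relying on the model default (LLAMA_POOLING_TYPE_UNSPECIFIED).
static llama_context * make_embedding_context(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                    // field renamed from `embedding` in this range
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // or LLAMA_POOLING_TYPE_NONE / _CLS
    return llama_new_context_with_model(model, cparams);
}
```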
Nindaleth
87c2e8b279 gguf-dump : support i-quants (#5841)
Co-authored-by: Black_Fox <radekliska@gmail.com>
2024-03-03 10:43:42 +02:00
compilade
de9692a7d2 llama : fix llama_copy_state_data with fragmented KV cache (#5840)
The row size of the saved states was based on kv_self.head while
it should be based on llama_kv_cache_cell_max.

Existing session files should still work.

* llama : fix llama_kv_cache_cell_max inability to return 1

I've also changed its return type to uint32_t,
because this function is always used to set the value of uint32_t variables,
and because the index already has this type.

* llama : fix state size calculation

Some bytes in the state were unaccounted for in llama_get_state_size.
Since the logits reserve so much space, it did not cause problems.
2024-03-03 10:41:55 +02:00
Pierrick Hymbert
e6029348e8 ci : schedule slow server tests only on Release or on demand (#5839) 2024-03-03 10:35:23 +02:00
Pierrick Hymbert
8ef969afce server : init http requests thread pool with --parallel if set (#5836) 2024-03-03 09:48:36 +02:00
Georgi Gerganov
fa974646e1 flake.lock: Update (#5842)
Flake lock file updates:

• Updated input 'flake-parts':
    'github:hercules-ci/flake-parts/b253292d9c0a5ead9bc98c4e9a26c6312e27d69f' (2024-02-01)
  → 'github:hercules-ci/flake-parts/f7b3c975cf067e56e7cda6cb098ebe3fb4d74ca2' (2024-03-01)
• Updated input 'flake-parts/nixpkgs-lib':
    'github:NixOS/nixpkgs/97b17f32362e475016f942bbdfda4a4a72a8a652?dir=lib' (2024-01-29)
  → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8?dir=lib' (2024-02-29)
• Updated input 'nixpkgs':
    'github:NixOS/nixpkgs/cbc4211f0afffe6dfd2478a62615dd5175a13f9a' (2024-02-23)
  → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8' (2024-02-29)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-03-02 20:11:31 -08:00
Pierrick Hymbert
9731134296 server: tests: passkey challenge / self-extend with context shift demo (#5832)
* server: tests: add models endpoint scenario

* server: /v1/models add some metadata

* server: tests: add debug field in context before scenario

* server: tests: download model from HF, add batch size

* server: tests: add passkey test

* server: tests: add group attention params

* server: do not truncate prompt tokens if self-extend through group attention is enabled

* server: logs: do not truncate log values

* server: tests - passkey - first good working value of nga

* server: tests: fix server timeout

* server: tests: fix passkey, add doc, fix regex content matching, fix timeout

* server: tests: fix regex content matching

* server: tests: schedule slow tests on master

* server: metrics: fix when no prompt processed

* server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1

* server: tests: increase timeout for completion

* server: tests: keep only the PHI-2 test

* server: tests: passkey add a negative test
2024-03-02 22:00:14 +01:00
Michael Podvitskiy
4a6e2d6142 llama : add abort_callback to interrupt computation (#5409)
* using abort_callback from ggml to stop llama computation

* format fix

* a brief explaining comment

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-03-02 21:52:25 +02:00
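A minimal sketch of wiring up such a callback, assuming the entry point this PR adds is named llama_set_abort_callback() (the name is an assumption; the PR text only describes the mechanism) and that the callback follows ggml's convention of returning true to abort:

```cpp
#include "llama.h"
#include <atomic>

// Sketch: interrupt a long llama_decode() from another thread.
// The flag can be flipped from a signal handler or a UI thread.
static std::atomic<bool> g_should_abort{false};

static bool abort_cb(void * /*data*/) {
    return g_should_abort.load(); // true => abort the current graph computation
}

static void install_abort_handler(llama_context * ctx) {
    // entry point name assumed from the PR description above
    llama_set_abort_callback(ctx, abort_cb, nullptr);
}
```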
Georgi Gerganov
494c870326 ggml : fix IQ3_S AVX implementation (#5834)
ggml-ci
2024-03-02 20:00:49 +02:00
Jared Van Bortel
4d4d2366fc convert : automatically fall back to HfVocab if tokenizer.model doesn't exist (#5821) 2024-03-02 12:27:26 -05:00
Jared Van Bortel
c7a0ad8ec9 convert-hf : make model class definitions self-contained (#5825) 2024-03-02 12:21:47 -05:00
Kawrakow
bbde6eb256 ggml : IQ3_S improvements (#5829)
* iq3_s: somewhat faster AVX2 dot product

On Ryzen a 7950X TG-128 increases to 16 t/s from 15.5 t/s using
16 threads. For 8 threads it is 13.85 t/s vs 11.75 t/s.
PP-512 increases to 28.5 t/s from 23.8 t/s.

* iq3_s: somewhat faster ARM_NEON dot product

Still dog slow - 10.7 t/s up from 9.9 t/s.

* iq3_s: another small ARM_NEON improvement

10.7 -> 11.0 t/s. Using vmulq_s8 is faster than the xor - sub trick
that works best on AVX2.

* iq3_s: minor improvement on Metal

49.4 t/s -> 50.3 t/s

* iq3_s: PPL improvement

E.g., for a context of 4096 LLaMA-v2-7B goes to 5.1340 from 5.1653.

* iq3_s: use new grid everywhere

* Fix ARM_NEON

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2024-03-02 17:00:51 +02:00
Georgi Gerganov
ef2cd694c4 scripts : add pod-llama.sh 2024-03-02 16:54:20 +02:00
Xuan Son Nguyen
6c32d8c7ad llama : refactor internal quantization functions (#5830) 2024-03-02 16:19:09 +02:00
62 changed files with 48139 additions and 47946 deletions

View File

@@ -1,5 +1,6 @@
{
lib,
glibc,
config,
stdenv,
mkShell,
@@ -30,6 +31,11 @@
useRocm ? config.rocmSupport,
useVulkan ? false,
llamaVersion ? "0.0.0", # Arbitrary version, substituted by the flake
# It's necessary to consistently use backendStdenv when building with CUDA support,
# otherwise we get libstdc++ errors downstream.
effectiveStdenv ? if useCuda then cudaPackages.backendStdenv else stdenv,
enableStatic ? effectiveStdenv.hostPlatform.isStatic
}@inputs:
let
@@ -41,10 +47,7 @@ let
versionOlder
;
# It's necessary to consistently use backendStdenv when building with CUDA support,
# otherwise we get libstdc++ errors downstream.
stdenv = throw "Use effectiveStdenv instead";
effectiveStdenv = if useCuda then cudaPackages.backendStdenv else inputs.stdenv;
suffices =
lib.optionals useBlas [ "BLAS" ]
@@ -167,6 +170,9 @@ effectiveStdenv.mkDerivation (
# TODO: Replace with autoAddDriverRunpath
# once https://github.com/NixOS/nixpkgs/pull/275241 has been merged
cudaPackages.autoAddOpenGLRunpathHook
]
++ optionals (effectiveStdenv.hostPlatform.isGnu && enableStatic) [
glibc.static
];
buildInputs =
@@ -181,7 +187,7 @@ effectiveStdenv.mkDerivation (
[
(cmakeBool "LLAMA_NATIVE" false)
(cmakeBool "LLAMA_BUILD_SERVER" true)
(cmakeBool "BUILD_SHARED_LIBS" true)
(cmakeBool "BUILD_SHARED_LIBS" (!enableStatic))
(cmakeBool "CMAKE_SKIP_BUILD_RPATH" true)
(cmakeBool "LLAMA_BLAS" useBlas)
(cmakeBool "LLAMA_CLBLAST" useOpenCL)
@@ -190,6 +196,7 @@ effectiveStdenv.mkDerivation (
(cmakeBool "LLAMA_METAL" useMetalKit)
(cmakeBool "LLAMA_MPI" useMpi)
(cmakeBool "LLAMA_VULKAN" useVulkan)
(cmakeBool "LLAMA_STATIC" enableStatic)
]
++ optionals useCuda [
(

View File

@@ -3,6 +3,11 @@ name: Server
on:
workflow_dispatch: # allows manual triggering
inputs:
slow_tests:
description: 'Run slow tests'
required: true
type: boolean
push:
branches:
- master
@@ -10,6 +15,8 @@ on:
pull_request:
types: [opened, synchronize, reopened]
paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/tests/**.*']
schedule:
- cron: '0 0 * * *'
jobs:
server:
@@ -70,14 +77,15 @@ jobs:
run: |
pip install -r examples/server/tests/requirements.txt
- name: Download models
id: download_models
run: |
cd examples/server/tests
../../../scripts/hf.sh --repo ggml-org/models --file tinyllamas/stories260K.gguf
- name: Tests
id: server_integration_test
id: server_integration_tests
run: |
cd examples/server/tests
PORT=8888 ./tests.sh
- name: Slow tests
id: server_integration_tests_slow
if: ${{ github.event.schedule != '' && matrix.build_type == 'Release' || github.event.inputs.slow_tests == 'true' }}
run: |
cd examples/server/tests
PORT=8888 ./tests.sh --stop --no-skipped --no-capture --tags slow

.gitignore (vendored, 1 line changed)
View File

@@ -45,6 +45,7 @@ models-mnt
/embedding
/gguf
/gguf-llama-simple
/gritlm
/imatrix
/infill
/libllama.so

View File

@@ -2,7 +2,7 @@
BUILD_TARGETS = \
main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o
# Binaries only useful for tests
TEST_TARGETS = \
@@ -720,6 +720,10 @@ embedding: examples/embedding/embedding.cpp ggml.o llama.o $(C
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
gritlm: examples/gritlm/gritlm.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

View File

@@ -8,6 +8,11 @@
Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) in pure C/C++
### Recent API changes
- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
- [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849
### Hot topics
- The `api_like_OAI.py` script has been removed - use `server` instead ([#5766](https://github.com/ggerganov/llama.cpp/issues/5766#issuecomment-1969037761))
@@ -786,7 +791,7 @@ And after 4.45 hours, you will have the final perplexity.
### Interactive mode
If you want a more ChatGPT-like experience, you can run in interactive mode by passing `-i` as a parameter.
In this mode, you can always interrupt generation by pressing Ctrl+C and entering one or more lines of text, which will be converted into tokens and appended to the current context. You can also specify a *reverse prompt* with the parameter `-r "reverse prompt string"`. This will result in user input being prompted whenever the exact tokens of the reverse prompt string are encountered in the generation. A typical use is to use a prompt that makes LLaMa emulate a chat between multiple users, say Alice and Bob, and pass `-r "Alice:"`.
In this mode, you can always interrupt generation by pressing Ctrl+C and entering one or more lines of text, which will be converted into tokens and appended to the current context. You can also specify a *reverse prompt* with the parameter `-r "reverse prompt string"`. This will result in user input being prompted whenever the exact tokens of the reverse prompt string are encountered in the generation. A typical use is to use a prompt that makes LLaMA emulate a chat between multiple users, say Alice and Bob, and pass `-r "Alice:"`.
Here is an example of a few-shot interaction, invoked with the command
@@ -850,7 +855,7 @@ Sample run:
```
== Running in interactive mode. ==
- Press Ctrl+C to interject at any time.
- Press Return to return control to LLaMa.
- Press Return to return control to LLaMA.
- If you want to submit another line, end your input in '\'.
Below is an instruction that describes a task. Write a response that appropriately completes the request.

View File

@@ -19,7 +19,12 @@ if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../.git")
endif()
endif()
set(GIT_INDEX "${GIT_DIR}/index")
if(EXISTS "${GIT_DIR}/index")
set(GIT_INDEX "${GIT_DIR}/index")
else()
message(WARNING "Git index not found in git repository.")
set(GIT_INDEX "")
endif()
else()
message(WARNING "Git repository not found; to enable automatic generation of build info, make sure Git is installed and the project is a Git repository.")
set(GIT_INDEX "")

View File

@@ -335,6 +335,16 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--pooling") {
if (++i >= argc) {
invalid_param = true;
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else { invalid_param = true; break; }
} else if (arg == "--defrag-thold" || arg == "-dt") {
if (++i >= argc) {
invalid_param = true;
@@ -503,12 +513,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_sequences = std::stoi(argv[i]);
} else if (arg == "--p-accept" || arg == "-pa") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.p_accept = std::stof(argv[i]);
} else if (arg == "--p-split" || arg == "-ps") {
if (++i >= argc) {
invalid_param = true;
@@ -1014,6 +1018,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" --pooling {none,mean,cls}\n");
printf(" pooling type for embeddings, use model default if unspecified\n");
printf(" -dt N, --defrag-thold N\n");
printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
@@ -1032,7 +1038,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
@@ -1287,7 +1292,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.seed = params.seed;
cparams.logits_all = params.logits_all;
cparams.embedding = params.embedding;
cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
cparams.rope_freq_base = params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale;
@@ -1296,6 +1301,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.pooling_type = params.pooling_type;
cparams.defrag_thold = params.defrag_thold;
cparams.offload_kqv = !params.no_kv_offload;

View File

@@ -43,7 +43,7 @@ extern char const *LLAMA_BUILD_TARGET;
int32_t get_num_physical_cores();
struct gpt_params {
uint32_t seed = -1; // RNG seed
uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
int32_t n_threads = get_num_physical_cores();
int32_t n_threads_draft = -1;
@@ -53,11 +53,10 @@ struct gpt_params {
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_draft = 8; // number of tokens to draft during speculative decoding
int32_t n_draft = 5; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
float p_accept = 0.5f; // speculative decoding accept probability
float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
@@ -76,8 +75,11 @@ struct gpt_params {
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = -1.0f; // KV cache defragmentation threshold
int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
// // sampling parameters
struct llama_sampling_params sparams;

View File

@@ -295,6 +295,77 @@ static llama_token llama_sampling_sample_impl(
return id;
}
static llama_token_data_array llama_sample_probability_distribution_impl(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
const llama_sampling_params & params = ctx_sampling->params;
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;
const bool penalize_nl = params.penalize_nl;
auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;
// Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx);
// Declare original_logits at the beginning of the function scope
std::vector<float> original_logits;
// apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
if (ctx_cfg) {
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
}
cur.clear();
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
// apply penalties
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
if (penalty_tokens_used_size) {
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
llama_sample_repetition_penalties(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
cur_p.data[idx].logit = nl_logit;
break;
}
}
}
}
// apply grammar checks
if (ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}
llama_sample_softmax(ctx_main, &cur_p);
return cur_p;
}
llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
@@ -304,6 +375,14 @@ llama_token llama_sampling_sample(
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
}
llama_token_data_array llama_sampling_probability_distribution(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
}
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,

View File

@@ -131,6 +131,13 @@ llama_token llama_sampling_sample(
struct llama_context * ctx_cfg,
int idx = 0);
// returns the probability that token of given id will be sampled
llama_token_data_array llama_sampling_probability_distribution(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx = 0);
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,

View File

@@ -8,9 +8,10 @@ import json
import os
import re
import sys
from abc import ABC, abstractmethod
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, Sequence, cast
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterator, Sequence, TypeVar, cast
import numpy as np
import torch
@@ -36,7 +37,12 @@ class SentencePieceTokenTypes(IntEnum):
BYTE = 6
class Model:
AnyModel = TypeVar("AnyModel", bound="type[Model]")
class Model(ABC):
_model_classes: dict[str, type[Model]] = {}
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool):
self.dir_model = dir_model
self.ftype = ftype
@@ -47,10 +53,14 @@ class Model:
self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
self.part_names = self._get_part_names()
self.hparams = Model.load_hparams(self.dir_model)
self.model_arch = self._get_model_architecture()
self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False)
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
@property
@abstractmethod
def model_arch(self) -> gguf.MODEL_ARCH:
pass
def find_hparam(self, keys: Sequence[str], optional: bool = False) -> Any:
key = next((k for k in keys if k in self.hparams), None)
if key is not None:
@@ -176,55 +186,22 @@ class Model:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
return json.load(f)
@staticmethod
def from_model_architecture(model_architecture):
if model_architecture == "GPTNeoXForCausalLM":
return GPTNeoXModel
if model_architecture == "BloomForCausalLM":
return BloomModel
if model_architecture == "MPTForCausalLM":
return MPTModel
if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
return BaichuanModel
if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
return FalconModel
if model_architecture == "GPTBigCodeForCausalLM":
return StarCoderModel
if model_architecture == "GPTRefactForCausalLM":
return RefactModel
if model_architecture == "PersimmonForCausalLM":
return PersimmonModel
if model_architecture in ("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return StableLMModel
if model_architecture == "QWenLMHeadModel":
return QwenModel
if model_architecture == "Qwen2ForCausalLM":
return Model
if model_architecture == "MixtralForCausalLM":
return MixtralModel
if model_architecture == "GPT2LMHeadModel":
return GPT2Model
if model_architecture == "PhiForCausalLM":
return Phi2Model
if model_architecture == "PlamoForCausalLM":
return PlamoModel
if model_architecture == "CodeShellForCausalLM":
return CodeShellModel
if model_architecture == "OrionForCausalLM":
return OrionModel
if model_architecture == "InternLM2ForCausalLM":
return InternLM2Model
if model_architecture == "MiniCPMForCausalLM":
return MiniCPMModel
if model_architecture == "BertModel":
return BertModel
if model_architecture == "NomicBertModel":
return NomicBertModel
if model_architecture == "GemmaForCausalLM":
return GemmaModel
if model_architecture == "Starcoder2ForCausalLM":
return Model
return Model
@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
assert names
def func(modelcls: type[Model]):
for name in names:
cls._model_classes[name] = modelcls
return modelcls
return func
@classmethod
def from_model_architecture(cls, arch):
try:
return cls._model_classes[arch]
except KeyError:
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
def _is_model_safetensors(self) -> bool:
return Model.count_model_parts(self.dir_model, ".safetensors") > 0
@@ -239,57 +216,6 @@ class Model:
return ("pytorch_model.bin",)
return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
def _get_model_architecture(self) -> gguf.MODEL_ARCH:
arch = self.hparams["architectures"][0]
if arch == "GPTNeoXForCausalLM":
return gguf.MODEL_ARCH.GPTNEOX
if arch == "BloomForCausalLM":
return gguf.MODEL_ARCH.BLOOM
if arch == "MPTForCausalLM":
return gguf.MODEL_ARCH.MPT
if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
return gguf.MODEL_ARCH.BAICHUAN
if arch in ("FalconForCausalLM", "RWForCausalLM"):
return gguf.MODEL_ARCH.FALCON
if arch == "GPTBigCodeForCausalLM":
return gguf.MODEL_ARCH.STARCODER
if arch == "GPTRefactForCausalLM":
return gguf.MODEL_ARCH.REFACT
if arch == "PersimmonForCausalLM":
return gguf.MODEL_ARCH.PERSIMMON
if arch in ("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return gguf.MODEL_ARCH.STABLELM
if arch == "QWenLMHeadModel":
return gguf.MODEL_ARCH.QWEN
if arch == "Qwen2ForCausalLM":
return gguf.MODEL_ARCH.QWEN2
if arch == "MixtralForCausalLM":
return gguf.MODEL_ARCH.LLAMA
if arch == "GPT2LMHeadModel":
return gguf.MODEL_ARCH.GPT2
if arch == "PhiForCausalLM":
return gguf.MODEL_ARCH.PHI2
if arch == "PlamoForCausalLM":
return gguf.MODEL_ARCH.PLAMO
if arch == "CodeShellForCausalLM":
return gguf.MODEL_ARCH.CODESHELL
if arch == "OrionForCausalLM":
return gguf.MODEL_ARCH.ORION
if arch == "InternLM2ForCausalLM":
return gguf.MODEL_ARCH.INTERNLM2
if arch == "MiniCPMForCausalLM":
return gguf.MODEL_ARCH.MINICPM
if arch == "BertModel":
return gguf.MODEL_ARCH.BERT
if arch == "NomicBertModel":
return gguf.MODEL_ARCH.NOMIC_BERT
if arch == "GemmaForCausalLM":
return gguf.MODEL_ARCH.GEMMA
if arch == "Starcoder2ForCausalLM":
return gguf.MODEL_ARCH.STARCODER2
raise NotImplementedError(f'Architecture "{arch}" not supported!')
def _set_vocab_gpt2(self):
dir_model = self.dir_model
hparams = self.hparams
@@ -457,7 +383,10 @@ class Model:
special_vocab.add_to_gguf(self.gguf_writer)
@Model.register("GPTNeoXForCausalLM")
class GPTNeoXModel(Model):
model_arch = gguf.MODEL_ARCH.GPTNEOX
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
@@ -474,7 +403,10 @@ class GPTNeoXModel(Model):
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
@Model.register("BloomForCausalLM")
class BloomModel(Model):
model_arch = gguf.MODEL_ARCH.BLOOM
def set_gguf_parameters(self):
self.gguf_writer.add_name("Bloom")
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
@@ -566,7 +498,10 @@ class BloomModel(Model):
print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
@Model.register("MPTForCausalLM")
class MPTModel(Model):
model_arch = gguf.MODEL_ARCH.MPT
def set_gguf_parameters(self):
block_count = self.hparams["n_layers"]
self.gguf_writer.add_name(self.dir_model.name)
@@ -629,7 +564,10 @@ class MPTModel(Model):
self.gguf_writer.add_tensor(new_name, data)
@Model.register("OrionForCausalLM")
class OrionModel(Model):
model_arch = gguf.MODEL_ARCH.ORION
def set_vocab(self):
self._set_vocab_sentencepiece()
@@ -708,7 +646,10 @@ class OrionModel(Model):
self.gguf_writer.add_tensor(new_name, data)
@Model.register("BaichuanForCausalLM", "BaiChuanForCausalLM")
class BaichuanModel(Model):
model_arch = gguf.MODEL_ARCH.BAICHUAN
def set_vocab(self):
self._set_vocab_sentencepiece()
@@ -823,7 +764,10 @@ class BaichuanModel(Model):
return weights[r * n_part:r * n_part + r, ...]
@Model.register("FalconForCausalLM", "RWForCausalLM")
class FalconModel(Model):
model_arch = gguf.MODEL_ARCH.FALCON
def set_gguf_parameters(self):
block_count = self.hparams.get("num_hidden_layers")
if block_count is None:
@@ -916,7 +860,10 @@ class FalconModel(Model):
self.gguf_writer.add_tensor(new_name, data)
@Model.register("GPTBigCodeForCausalLM")
class StarCoderModel(Model):
model_arch = gguf.MODEL_ARCH.STARCODER
def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]
@@ -931,7 +878,10 @@ class StarCoderModel(Model):
self.gguf_writer.add_file_type(self.ftype)
@Model.register("GPTRefactForCausalLM")
class RefactModel(Model):
model_arch = gguf.MODEL_ARCH.REFACT
def set_gguf_parameters(self):
hidden_dim = self.hparams["n_embd"]
inner_dim = 4 * hidden_dim
@@ -1015,7 +965,10 @@ class RefactModel(Model):
self.gguf_writer.add_tensor(new_name, data)
@Model.register("PersimmonForCausalLM")
class PersimmonModel(Model):
model_arch = gguf.MODEL_ARCH.PERSIMMON
def set_gguf_parameters(self):
block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
head_count = self.hparams["num_attention_heads"]
@@ -1063,7 +1016,10 @@ class PersimmonModel(Model):
self.gguf_writer.add_tensor(new_name, data)
@Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM")
class StableLMModel(Model):
model_arch = gguf.MODEL_ARCH.STABLELM
def set_vocab(self):
if (self.dir_model / "tokenizer.json").is_file():
self._set_vocab_gpt2()
@@ -1087,12 +1043,18 @@ class StableLMModel(Model):
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))
@Model.register("MixtralForCausalLM")
class MixtralModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA
def set_vocab(self):
self._set_vocab_sentencepiece()
@Model.register("MiniCPMForCausalLM")
class MiniCPMModel(Model):
model_arch = gguf.MODEL_ARCH.MINICPM
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
self.gguf_writer.add_name("MiniCPM")
@@ -1169,7 +1131,10 @@ class MiniCPMModel(Model):
self.gguf_writer.add_tensor(new_name, data)
@Model.register("QWenLMHeadModel")
class QwenModel(Model):
model_arch = gguf.MODEL_ARCH.QWEN
@staticmethod
def token_bytes_to_string(b):
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
@@ -1249,7 +1214,15 @@ class QwenModel(Model):
self.gguf_writer.add_tensor(new_name, data)
@Model.register("Qwen2ForCausalLM")
class Qwen2Model(Model):
model_arch = gguf.MODEL_ARCH.QWEN2
@Model.register("GPT2LMHeadModel")
class GPT2Model(Model):
model_arch = gguf.MODEL_ARCH.GPT2
def set_gguf_parameters(self):
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_block_count(self.hparams["n_layer"])
@@ -1311,7 +1284,10 @@ class GPT2Model(Model):
self.gguf_writer.add_tensor("output.weight", data)
@Model.register("PhiForCausalLM")
class Phi2Model(Model):
model_arch = gguf.MODEL_ARCH.PHI2
def set_gguf_parameters(self):
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
@@ -1333,7 +1309,10 @@ class Phi2Model(Model):
self.gguf_writer.add_add_bos_token(False)
@Model.register("PlamoForCausalLM")
class PlamoModel(Model):
model_arch = gguf.MODEL_ARCH.PLAMO
def set_vocab(self):
self._set_vocab_sentencepiece()
@@ -1412,7 +1391,10 @@ class PlamoModel(Model):
self.gguf_writer.add_tensor(new_name, data)
@Model.register("CodeShellForCausalLM")
class CodeShellModel(Model):
model_arch = gguf.MODEL_ARCH.CODESHELL
def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]
@@ -1477,7 +1459,10 @@ class CodeShellModel(Model):
print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
@Model.register("InternLM2ForCausalLM")
class InternLM2Model(Model):
model_arch = gguf.MODEL_ARCH.INTERNLM2
def set_vocab(self):
# (TODO): Is there a better way?
# Copy from _set_vocab_sentencepiece, The only difference is that we will treat the character
@@ -1649,7 +1634,10 @@ in chat mode so that the conversation can end normally.")
self.post_write_tensors(tensor_map, name, data_torch)
@Model.register("BertModel")
class BertModel(Model):
model_arch = gguf.MODEL_ARCH.BERT
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.vocab_size = None
@@ -1659,16 +1647,17 @@ class BertModel(Model):
self.gguf_writer.add_causal_attention(False)
# get pooling path
with open(self.dir_model / "modules.json", encoding="utf-8") as f:
modules = json.load(f)
pooling_path = None
for mod in modules:
if mod["type"] == "sentence_transformers.models.Pooling":
pooling_path = mod["path"]
break
module_path = self.dir_model / "modules.json"
if module_path.is_file():
with open(module_path, encoding="utf-8") as f:
modules = json.load(f)
for mod in modules:
if mod["type"] == "sentence_transformers.models.Pooling":
pooling_path = mod["path"]
break
# get pooling type
pooling_type = gguf.PoolingType.NONE
if pooling_path is not None:
with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
pooling = json.load(f)
@@ -1678,8 +1667,7 @@ class BertModel(Model):
pooling_type = gguf.PoolingType.CLS
else:
raise NotImplementedError("Only MEAN and CLS pooling types supported")
self.gguf_writer.add_pooling_type(pooling_type.value)
self.gguf_writer.add_pooling_type(pooling_type)
def set_vocab(self):
path = self.dir_model
@@ -1755,7 +1743,10 @@ class BertModel(Model):
self.gguf_writer.add_tensor(new_name, data)
@Model.register("NomicBertModel")
class NomicBertModel(BertModel):
model_arch = gguf.MODEL_ARCH.NOMIC_BERT
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -1792,7 +1783,10 @@ class NomicBertModel(BertModel):
yield name, data
@Model.register("GemmaForCausalLM")
class GemmaModel(Model):
model_arch = gguf.MODEL_ARCH.GEMMA
def set_vocab(self):
self._set_vocab_sentencepiece()
@@ -1848,6 +1842,11 @@ class GemmaModel(Model):
self.gguf_writer.add_tensor(new_name, data)
@Model.register("Starcoder2ForCausalLM")
class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2
###### CONVERSION LOGIC ######

View File

@@ -373,7 +373,7 @@ def handle_metadata(cfg, hp):
raise ValueError('Unable to load metadata')
vocab_path = Path(cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir)
vocab_factory = convert.VocabFactory(vocab_path)
vocab, special_vocab = vocab_factory.load_vocab(cfg.vocabtype, cfg.model_metadata_dir)
vocab, special_vocab = vocab_factory.load_vocab(cfg.vocabtype.split(","), cfg.model_metadata_dir)
convert.check_vocab_size(params, vocab)
return params, vocab, special_vocab
@@ -398,8 +398,8 @@ def handle_args():
help ='Load HuggingFace/.pth vocab and metadata from the specified directory')
parser.add_argument("--vocab-dir", type=Path,
help="directory containing tokenizer.model, if separate from model file - only meaningful with --model-metadata-dir")
parser.add_argument("--vocabtype", choices=["spm", "bpe"], default="spm",
help="vocab format - only meaningful with --model-metadata-dir and/or --vocab-dir (default: spm)")
parser.add_argument("--vocabtype", default="spm,hfft",
help="vocab format - only meaningful with --model-metadata-dir and/or --vocab-dir (default: spm,hfft)")
return parser.parse_args()

View File

@@ -1282,35 +1282,32 @@ def load_some_model(path: Path) -> ModelPlus:
class VocabFactory:
_FILES = {"spm": "tokenizer.model", "bpe": "vocab.json", "hfft": "tokenizer.json"}
def __init__(self, path: Path):
self.path = path
self.files: dict[str, Path | None] = {
"tokenizer.model": None,
"vocab.json": None,
"tokenizer.json": None,
}
self._detect_files()
self.file_paths = self._detect_files()
print(f"Found vocab files: {self.file_paths}")
def _detect_files(self):
for file in self.files.keys():
file_path = self.path / file
parent_file_path = self.path.parent / file
if file_path.exists():
self.files[file] = file_path
elif parent_file_path.exists():
self.files[file] = parent_file_path
print(f"Found vocab files: {self.files}")
def _detect_files(self) -> dict[str, Path | None]:
def locate(file: str) -> Path | None:
if (path := self.path / file).exists():
return path
if (path := self.path.parent / file).exists():
return path
return None
def _select_file(self, vocabtype: str | None) -> Path:
if vocabtype in ["spm", "bpe"]:
for file_key in self.files.keys():
if (file := self.files[file_key]) is not None:
return file
raise FileNotFoundError(f"{vocabtype} vocab not found.")
if vocabtype == "hfft":
# For Hugging Face Fast Tokenizer, return the directory path instead of a specific file
return self.path
raise ValueError(f"Unsupported vocabulary type {vocabtype}")
return {vt: locate(f) for vt, f in self._FILES.items()}
def _select_file(self, vocab_types: list[str]) -> tuple[str, Path]:
for vtype in vocab_types:
try:
path = self.file_paths[vtype]
except KeyError:
raise ValueError(f"Unsupported vocabulary type {vtype}") from None
if path is not None:
return vtype, path
raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}")
def _create_special_vocab(self, vocab: Vocab, vocabtype: str, model_parent_path: Path) -> gguf.SpecialVocab:
load_merges = vocabtype == "bpe"
@@ -1322,30 +1319,30 @@ class VocabFactory:
n_vocab=n_vocab,
)
def load_vocab(self, vocabtype: str, model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
path = self._select_file(vocabtype)
print(f"Loading vocab file '{path}', type '{vocabtype}'")
def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
vocab_type, path = self._select_file(vocab_types)
print(f"Loading vocab file {path!r}, type {vocab_type!r}")
added_tokens_path = path.parent / "added_tokens.json"
vocab: Vocab
if vocabtype == "bpe":
if vocab_type == "bpe":
vocab = BpeVocab(
path, added_tokens_path if added_tokens_path.exists() else None
)
elif vocabtype == "spm":
elif vocab_type == "spm":
vocab = SentencePieceVocab(
path, added_tokens_path if added_tokens_path.exists() else None
)
elif vocabtype == "hfft":
elif vocab_type == "hfft":
vocab = HfVocab(
path, added_tokens_path if added_tokens_path.exists() else None
path.parent, added_tokens_path if added_tokens_path.exists() else None
)
else:
raise ValueError(f"Unsupported vocabulary type {vocabtype}")
raise ValueError(vocab_type)
# FIXME: Respect --vocab-dir?
special_vocab = self._create_special_vocab(
vocab,
vocabtype,
vocab_type,
model_parent_path,
)
return vocab, special_vocab
@@ -1379,15 +1376,14 @@ def main(args_in: list[str] | None = None) -> None:
if np.uint32(1) == np.uint32(1).newbyteorder("<"):
# We currently only support Q8_0 output on little endian systems.
output_choices.append("q8_0")
vocab_types = ["spm", "bpe", "hfft"]
parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
parser = argparse.ArgumentParser(description="Convert a LLaMA model to a GGML compatible file")
parser.add_argument("--awq-path", type=Path, help="Path to scale awq cache file", default=None)
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
parser.add_argument("--vocab-type", choices=vocab_types, help="The vocabulary format used to define the tokenizer model (default: spm)", default="spm")
parser.add_argument("--vocab-type", help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft)", default="spm,hfft")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
@@ -1448,7 +1444,7 @@ def main(args_in: list[str] | None = None) -> None:
model_parent_path = model_plus.paths[0].parent
vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
vocab_factory = VocabFactory(vocab_path)
vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type, model_parent_path)
vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type.split(","), model_parent_path)
if args.vocab_only:
if not args.outfile:
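The net effect of this hunk: `--vocab-type` now takes a comma-separated priority list (default `spm,hfft`) and `_select_file` returns the first vocabulary file that actually exists. A minimal standalone sketch of that ordering logic, assuming a hypothetical `file_paths` mapping rather than the real `VocabFactory`:

```python
from pathlib import Path

# hypothetical vocab type -> located file (None if the file is absent)
file_paths = {
    "spm": None,                     # no tokenizer.model in the model dir
    "bpe": None,
    "hfft": Path("tokenizer.json"),  # HF fast tokenizer is present
}

def select_vocab(vocab_types: list[str]) -> tuple[str, Path]:
    for vtype in vocab_types:
        if vtype not in file_paths:
            raise ValueError(f"Unsupported vocabulary type {vtype}")
        path = file_paths[vtype]
        if path is not None:
            return vtype, path
    raise FileNotFoundError(f"Could not find any of {vocab_types}")

# the default CLI value tries SentencePiece first and falls back to the HF tokenizer
print(select_vocab("spm,hfft".split(",")))  # ('hfft', PosixPath('tokenizer.json'))
```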

View File

@@ -20,6 +20,7 @@ else()
add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(embedding)
add_subdirectory(finetune)
add_subdirectory(gritlm)
add_subdirectory(infill)
add_subdirectory(llama-bench)
add_subdirectory(llava)

View File

@@ -19,11 +19,11 @@ static std::vector<std::string> split_lines(const std::string & s) {
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
for (size_t i = 0; i < tokens.size(); i++) {
llama_batch_add(batch, tokens[i], i, { seq_id }, false);
llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
}
}
static void normalize(float * vec, float * out, int n) {
static void normalize(const float * vec, float * out, int n) {
float norm = 0;
for (int i = 0; i < n; i++) {
norm += vec[i] * vec[i];
@@ -45,10 +45,23 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}
// normalize on copy
for (int k = 0; k < n_seq; k++) {
float * emb = llama_get_embeddings_ith(ctx, k);
float * out = output + k * n_embd;
normalize(emb, out, n_embd);
for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
continue;
}
// try to get sequence embeddings - supported only when pooling_type is not NONE
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
if (embd == NULL) {
fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
continue;
}
}
float * out = output + batch.seq_id[i][0] * n_embd;
normalize(embd, out, n_embd);
}
}
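The rewritten loop prefers the pooled sequence embedding (`llama_get_embeddings_seq`) and falls back to the last token's embedding, writing each result into the output buffer at offset `seq_id * n_embd` while L2-normalizing on the copy. The indexing and normalize-on-copy step, sketched in NumPy with stand-in inputs:

```python
import numpy as np

n_seq, n_embd = 3, 8
output = np.zeros((n_seq, n_embd), dtype=np.float32)

# stand-in for the per-sequence raw embeddings returned by the backend
raw = {seq_id: np.random.rand(n_embd).astype(np.float32) for seq_id in range(n_seq)}

for seq_id, emb in raw.items():
    output[seq_id] = emb / np.linalg.norm(emb)  # normalize on copy
```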
@@ -132,7 +145,7 @@ int main(int argc, char ** argv) {
// initialize batch
const int n_prompts = prompts.size();
struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
// allocate output
const int n_embd = llama_n_embd(model);
@@ -145,6 +158,7 @@ int main(int argc, char ** argv) {
for (int k = 0; k < n_prompts; k++) {
// clamp to n_batch tokens
auto & inp = inputs[k];
const uint64_t n_toks = inp.size();
// encode if at capacity

View File

@@ -0,0 +1,5 @@
set(TARGET gritlm)
add_executable(${TARGET} gritlm.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/gritlm/gritlm.cpp (new file, 243 lines)
View File

@@ -0,0 +1,243 @@
#include "common.h"
#include "llama.h"
#include <string>
#include <vector>
// #define GRIT_DEBUG
static float dot_product(const std::vector<float> & v1, const std::vector<float> & v2) {
float dot = 0.0f;
for (uint64_t i = 0; i < v1.size(); ++i) {
dot += v1[i] * v2[i];
}
return dot;
}
static float norm(const std::vector<float> & v) {
return std::sqrt(dot_product(v, v));
}
static float cosine_similarity(const std::vector<float> & v1, const std::vector<float> & v2) {
return dot_product(v1, v2) / (norm(v1) * norm(v2));
}
static void normalize(const std::vector<float> & in, float * out) {
float inorm = norm(in);
for (uint64_t i = 0; i < in.size(); i++) {
out[i] = in[i] / inorm;
}
}
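These helpers are plain L2 normalization and cosine similarity; an equivalent NumPy sketch for reference (not part of the example itself):

```python
import numpy as np

def cosine_similarity(v1: np.ndarray, v2: np.ndarray) -> float:
    # dot(v1, v2) / (||v1|| * ||v2||), the same formula as the C++ helpers above
    return float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))

def normalize(v: np.ndarray) -> np.ndarray:
    return v / np.linalg.norm(v)

# once both vectors are normalized, cosine similarity reduces to a dot product
a = normalize(np.array([1.0, 2.0, 3.0]))
b = normalize(np.array([3.0, 2.0, 1.0]))
print(np.dot(a, b))  # ~0.714
```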
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
auto result = std::vector<std::vector<float>>{};
auto mdl = llama_get_model(ctx);
auto batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
for (uint64_t i = 0; i < sentences.size(); i++) {
llama_batch_clear(batch);
std::string input_string = instruction + sentences[i];
std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);
auto n_toks = (int32_t)inputs.size();
// GritLM seems to use an empty string ("") as the embed EOS
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
// inputs.push_back(llama_token_eos(mdl));
// we want to ignore instruction tokens for mean pooling
std::vector<llama_token> inputs_instruct = llama_tokenize(mdl, instruction, true, false);
auto n_inst = (int32_t)inputs_instruct.size();
#ifdef GRIT_DEBUG
// debug tokens - these should match the ones shown in the GritLM sample
std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
});
std::printf("\n");
#endif
// add input to batch (this increments n_tokens)
for (int32_t j = 0; j < n_toks; j++) {
llama_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
}
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
// run model
llama_decode(ctx, batch);
// get embedding dimensions
uint64_t n_embd = llama_n_embd(mdl);
// allocate embedding output
std::vector<float> emb_unorm(n_embd, 0.0f);
// sum up all token embeddings
for (int32_t k = n_inst; k < n_toks; k++) {
float * emb = llama_get_embeddings_ith(ctx, k);
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] += emb[j];
}
}
// divide by number of tokens (mean pooling)
uint64_t n_sent = n_toks - n_inst;
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] /= n_sent;
}
auto emb_norm = std::vector<float>(emb_unorm.size());
normalize(emb_unorm, emb_norm.data());
result.push_back(emb_norm);
#ifdef GRIT_DEBUG
// print out emb_norm
std::printf("embedding %ld: ", i);
for (uint64_t j = 0; j < n_embd; j++) {
std::printf("%.5f ", emb_norm[j]);
}
std::printf("\n\n");
#endif
}
llama_batch_free(batch);
return result;
}
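`encode` mean-pools the per-token embeddings over the sentence tokens only (the first `n_inst` instruction tokens are skipped) and then L2-normalizes the pooled vector. A hedged NumPy sketch of the same pooling, with `token_embeddings` and `n_inst` as stand-in inputs:

```python
import numpy as np

# stand-ins: one row per token of "instruction + sentence"; the first n_inst rows are instruction tokens
token_embeddings = np.random.rand(10, 4096).astype(np.float32)
n_inst = 4

# mean pooling over the sentence tokens only
emb_unorm = token_embeddings[n_inst:].mean(axis=0)

# normalize so that downstream cosine similarity becomes a dot product
emb_norm = emb_unorm / np.linalg.norm(emb_unorm)
```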
static std::string aggregate_pieces(const std::vector<std::string> & pieces) {
// calculate total length required
size_t length = 0;
for (const auto & str : pieces) {
length += str.size();
}
// reserve memory
std::string result;
result.reserve(length);
// append pieces
for (const auto & str : pieces) {
result += str;
}
return result;
}
static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
std::vector<std::string> pieces;
const llama_model * mdl = llama_get_model(ctx);
llama_token eos_token = llama_token_eos(mdl);
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
int32_t i_current_token = 0;
while (true) {
llama_batch_clear(bat);
for (auto i = 0; i < inputs.size(); i++) {
llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == inputs.size() - 1);
}
inputs.clear();
llama_decode(ctx, bat);
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
for (auto token = 0; token < candidates.size(); token++) {
candidates[token] = llama_token_data{ token, logits[token], 0.0f };
}
auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
if (token == eos_token) {
break;
}
std::string piece = llama_token_to_piece(ctx, token);
if (stream) {
std::printf("%s", piece.c_str());
}
pieces.push_back(piece);
inputs.push_back(token);
}
llama_batch_free(bat);
return aggregate_pieces(pieces);
}
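`generate` is plain greedy decoding: every step takes the highest-logit token, streams it, and stops at EOS. The same loop, sketched with NumPy around a stand-in `logits_for(tokens)` call (a placeholder for one `llama_decode` + `llama_get_logits_ith` round trip):

```python
import numpy as np

def greedy_decode(logits_for, tokens: list[int], eos_id: int, max_new: int = 64) -> list[int]:
    for _ in range(max_new):
        next_id = int(np.argmax(logits_for(tokens)))  # greedy: argmax over the vocabulary
        if next_id == eos_id:
            break
        tokens.append(next_id)
    return tokens
```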
static std::string gritlm_instruction(const std::string & instruction) {
return !instruction.empty() ? "<|user|>\n" + instruction + "\n<|embed|>\n" : "<|embed|>\n";
}
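For reference, the two prompt prefixes this template produces:

```python
# with an instruction (used for the embedding queries):
"<|user|>\nGiven a scientific paper title, retrieve the paper's abstract\n<|embed|>\n"
# without an instruction (used for the retrieval documents):
"<|embed|>\n"
```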
int main(int argc, char * argv[])
{
gpt_params params;
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
llama_model_params mparams = llama_model_params_from_gpt_params(params);
llama_context_params cparams = llama_context_params_from_gpt_params(params);
llama_backend_init();
llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
// create new context - set to embedding mode
llama_context * embd_ctx = llama_new_context_with_model(mdl, cparams);
llama_set_embeddings(embd_ctx, true);
// create new context - default mode is causal
llama_context * causal_ctx = llama_new_context_with_model(mdl, cparams);
// samples taken from here: https://github.com/ContextualAI/gritlm#basic
// Embedding/Representation
{
std::string instruction = "Given a scientific paper title, retrieve the paper's abstract";
std::vector<std::string> queries = {
"Bitcoin: A Peer-to-Peer Electronic Cash System",
"Generative Representational Instruction Tuning",
};
std::vector<std::string> documents = {
"A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.",
"All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.",
};
// No need to add instruction for retrieval documents
std::vector<std::vector<float>> d_rep = encode(embd_ctx, documents, gritlm_instruction(""));
std::vector<std::vector<float>> q_rep = encode(embd_ctx, queries, gritlm_instruction(instruction));
float cosine_sim_q0_d0 = cosine_similarity(q_rep[0], d_rep[0]);
float cosine_sim_q0_d1 = cosine_similarity(q_rep[0], d_rep[1]);
float cosine_sim_q1_d0 = cosine_similarity(q_rep[1], d_rep[0]);
float cosine_sim_q1_d1 = cosine_similarity(q_rep[1], d_rep[1]);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[0].c_str(), cosine_sim_q1_d0);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
}
// Generation
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
{
const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
std::string response = generate(causal_ctx, prompt, true);
}
llama_free(embd_ctx);
llama_free(causal_ctx);
llama_free_model(mdl);
llama_backend_free();
return 0;
}

View File

@@ -378,10 +378,10 @@ int main(int argc, char ** argv) {
if (params.interactive) {
const char *control_message;
if (params.multiline_input) {
control_message = " - To return control to LLaMa, end your input with '\\'.\n"
control_message = " - To return control to LLaMA, end your input with '\\'.\n"
" - To return control without starting a new line, end your input with '/'.\n";
} else {
control_message = " - Press Return to return control to LLaMa.\n"
control_message = " - Press Return to return control to LLaMA.\n"
" - To return control without starting a new line, end your input with '/'.\n"
" - If you want to submit another line, end your input with '\\'.\n";
}

View File

@@ -511,6 +511,14 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
// tokenized antiprompts
std::vector<std::vector<llama_token>> antiprompt_ids;
antiprompt_ids.reserve(params.antiprompt.size());
for (const std::string & antiprompt : params.antiprompt) {
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
}
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
@@ -769,6 +777,18 @@ int main(int argc, char ** argv) {
}
}
// check for reverse prompt using special tokens
llama_token last_token = llama_sampling_last(ctx_sampling);
for (std::vector<llama_token> ids : antiprompt_ids) {
if (ids.size() == 1 && last_token == ids[0]) {
if (params.interactive) {
is_interacting = true;
}
is_antiprompt = true;
break;
}
}
if (is_antiprompt) {
LOG("found antiprompt: %s\n", last_output.c_str());
}
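This addition tokenizes every antiprompt once up front and, for single-token antiprompts, compares the last sampled token id directly instead of re-scanning the generated text. A small sketch of the check, with hypothetical token ids:

```python
# pre-tokenized antiprompts (ids are made up for illustration)
antiprompt_ids = [[128009], [2, 13]]  # one single-token and one multi-token antiprompt
last_token = 128009

# only single-token antiprompts can be matched against the last sampled token
is_antiprompt = any(len(ids) == 1 and ids[0] == last_token for ids in antiprompt_ids)
print(is_antiprompt)  # True
```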

examples/server-embd.py (new file, 34 lines)
View File

@@ -0,0 +1,34 @@
import asyncio
import requests
import numpy as np
n = 8
result = []
async def requests_post_async(*args, **kwargs):
return await asyncio.to_thread(requests.post, *args, **kwargs)
async def main():
model_url = "http://127.0.0.1:6900"
responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
url= f"{model_url}/embedding",
json= {"content": str(i)*1024}
) for i in range(n)])
for response in responses:
embedding = response.json()["embedding"]
print(embedding[-8:])
result.append(embedding)
asyncio.run(main())
# compute cosine similarity
for i in range(n-1):
for j in range(i+1, n):
embedding1 = np.array(result[i])
embedding2 = np.array(result[j])
similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
print(f"Similarity between {i} and {j}: {similarity:.2f}")

View File

@@ -18,7 +18,7 @@ The project is under active development, and we are [looking for feedback and co
- `--threads N`, `-t N`: Set the number of threads to use during generation.
- `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. If not specified, the number of threads will be set to the number of threads used for generation.
- `--threads-http N`: number of threads in the http server pool to process requests (default: `std::thread::hardware_concurrency()`)
- `--threads-http N`: number of threads in the http server pool to process requests (default: `max(std::thread::hardware_concurrency() - 1, --parallel N + 2)`)
- `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`).
- `-a ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses.
- `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. The size may differ in other models, for example, baichuan models were built with a context of 4096.

View File

@@ -413,7 +413,7 @@ struct llama_server_context
int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
if (res < 0) {
LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
sparams.chat_template = "<|im_start|>"; // llama_chat_apply_template only checks if <|im_start|> exist in the template
sparams.chat_template = "chatml";
}
}
@@ -441,8 +441,8 @@ struct llama_server_context
const int ga_w = params.grp_attn_w;
if (ga_n != 1) {
GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
@@ -1210,7 +1210,7 @@ struct llama_server_context
queue_results.send(res);
}
void send_embedding(server_slot &slot)
void send_embedding(server_slot & slot, const llama_batch & batch)
{
task_result res;
res.id = slot.task_id;
@@ -1219,6 +1219,7 @@ struct llama_server_context
res.stop = true;
const int n_embd = llama_n_embd(model);
if (!params.embedding)
{
LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
@@ -1229,12 +1230,29 @@ struct llama_server_context
}
else
{
const float *data = llama_get_embeddings(ctx);
std::vector<float> embedding(data, data + n_embd);
res.result_json = json
{
{"embedding", embedding},
};
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue;
}
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
if (embd == NULL) {
LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
res.result_json = json
{
{"embedding", std::vector<float>(n_embd, 0.0f)},
};
continue;
}
}
res.result_json = json
{
{"embedding", std::vector<float>(embd, embd + n_embd)},
};
}
}
queue_results.send(res);
}
@@ -1709,8 +1727,8 @@ struct llama_server_context
}
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
// if input prompt is too big, truncate it
if (slot.n_prompt_tokens >= slot.n_ctx)
// if input prompt is too big, truncate it, if group attention self-extend is disabled
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
{
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_block_size = n_left / 2;
@@ -1785,9 +1803,11 @@ struct llama_server_context
}
LOG_INFO("slot progression", {
{ "slot_id", slot.id },
{ "task_id", slot.task_id },
{ "n_past", slot.n_past },
{ "slot_id", slot.id },
{ "task_id", slot.task_id },
{ "n_past", slot.n_past },
{ "n_past_se", slot.n_past_se },
{ "ga_i", slot.ga_i },
{ "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
});
}
@@ -1843,7 +1863,7 @@ struct llama_server_context
ga_i += ga_w/ga_n;
}
}
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
slot_npast++;
}
@@ -1879,7 +1899,7 @@ struct llama_server_context
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
{
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
for (auto & slot : slots)
{
@@ -1952,7 +1972,7 @@ struct llama_server_context
// prompt evaluated for embedding
if (slot.embedding)
{
send_embedding(slot);
send_embedding(slot, batch_view);
slot.release();
slot.i_batch = -1;
continue;
@@ -2001,6 +2021,17 @@ struct llama_server_context
LOG_VERBOSE("slots updated", {});
return true;
}
json model_meta() {
return json{
{"vocab_type", llama_vocab_type(model)},
{"n_vocab", llama_n_vocab(model)},
{"n_ctx_train", llama_n_ctx_train(model)},
{"n_embd", llama_n_embd(model)},
{"n_params", llama_model_n_params(model)},
{"size", llama_model_size(model)},
};
}
};
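The new `meta` object is attached to each entry returned by `/v1/models`. Illustrative (made-up) values for a hypothetical 7B model, shown as a Python dict:

```python
meta = {
    "vocab_type": 1,          # llama_vocab_type enum value (made up here)
    "n_vocab": 32000,
    "n_ctx_train": 4096,
    "n_embd": 4096,
    "n_params": 6738415616,   # hypothetical parameter count
    "size": 3825065984,       # hypothetical model file size in bytes
}
```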
static void server_print_usage(const char *argv0, const gpt_params &params,
@@ -2013,7 +2044,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
printf(" -tb N, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)\n");
printf(" --threads-http N number of threads in the http server pool to process requests (default: hardware concurrency)\n");
printf(" --threads-http N number of threads in the http server pool to process requests (default: max(hardware concurrency - 1, --parallel N + 2))\n");
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
printf(" --rope-scaling {none,linear,yarn}\n");
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
@@ -2023,6 +2054,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" --pooling {none,mean,cls}\n");
printf(" pooling type for embeddings, use model default if unspecified\n");
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -2263,6 +2296,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.yarn_beta_slow = std::stof(argv[i]);
}
else if (arg == "--pooling")
{
if (++i >= argc) {
invalid_param = true;
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else { invalid_param = true; break; }
}
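`--pooling` controls how the per-token embeddings of a sequence are reduced to a single vector: `none` leaves the per-token embeddings as-is, `mean` averages them, and `cls` keeps the first token's embedding. A rough NumPy sketch of the latter two over a stand-in token matrix:

```python
import numpy as np

token_embd = np.random.rand(12, 4096).astype(np.float32)  # stand-in per-token embeddings

mean_pooled = token_embd.mean(axis=0)  # --pooling mean
cls_pooled  = token_embd[0]            # --pooling cls (first token)
```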
else if (arg == "--threads" || arg == "-t")
{
if (++i >= argc)
@@ -2317,7 +2362,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break;
}
params.n_batch = std::stoi(argv[i]);
params.n_batch = std::min(512, params.n_batch);
}
else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
{
@@ -2911,9 +2955,10 @@ int main(int argc, char **argv)
for (const auto& metric_def : metrics_def) {
std::string name = metric_def["name"];
std::string help = metric_def["help"];
prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
<< "# TYPE llamacpp:" << name << " " << type << "\n"
<< "llamacpp:" << name << " " << metric_def["value"] << "\n";
auto value = json_value(metric_def, "value", 0);
prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
<< "# TYPE llamacpp:" << name << " " << type << "\n"
<< "llamacpp:" << name << " " << value << "\n";
}
}
@@ -2994,6 +3039,7 @@ int main(int argc, char **argv)
state.store(SERVER_STATE_READY);
LOG_INFO("model loaded", {});
}
const auto model_meta = llama.model_meta();
if (sparams.chat_template.empty()) { // custom chat template is not supplied
// check if the template comes with the model is supported by us
@@ -3143,7 +3189,7 @@ int main(int argc, char **argv)
}
});
svr.Get("/v1/models", [&params](const httplib::Request& req, httplib::Response& res)
svr.Get("/v1/models", [&params, &model_meta](const httplib::Request& req, httplib::Response& res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
std::time_t t = std::time(0);
@@ -3152,10 +3198,11 @@ int main(int argc, char **argv)
{"object", "list"},
{"data", {
{
{"id", params.model_alias},
{"object", "model"},
{"created", t},
{"owned_by", "llamacpp"}
{"id", params.model_alias},
{"object", "model"},
{"created", t},
{"owned_by", "llamacpp"},
{"meta", model_meta}
},
}}
};
@@ -3452,10 +3499,12 @@ int main(int argc, char **argv)
}*/
//);
if (sparams.n_threads_http > 0) {
log_data["n_threads_http"] = std::to_string(sparams.n_threads_http);
svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
if (sparams.n_threads_http < 1) {
// +2 threads for monitoring endpoints
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
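// e.g. with 8 hardware threads and --parallel 4: max(4 + 2, 8 - 1) = 7 HTTP threads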
}
log_data["n_threads_http"] = std::to_string(sparams.n_threads_http);
svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
LOG_INFO("HTTP server listening", log_data);
// run the HTTP server in a thread - see comment below

View File

@@ -1,22 +1,30 @@
# Server tests
Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development) and [behave](https://behave.readthedocs.io/en/latest/):
* [issues.feature](./features/issues.feature) Pending issues scenario
* [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests
* [security.feature](./features/security.feature) Security, CORS and API Key
* [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc...
Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development)
and [behave](https://behave.readthedocs.io/en/latest/):
* [issues.feature](./features/issues.feature) Pending issues scenario
* [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests
* [security.feature](./features/security.feature) Security, CORS and API Key
* [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc...
Tests target GitHub workflows job runners with 4 vCPU.
Requests are using [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html), [asyncio](https://docs.python.org/fr/3/library/asyncio.html) based http client.
Requests are
using [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html), [asyncio](https://docs.python.org/fr/3/library/asyncio.html)
based http client.
Note: If the host architecture's inference speed is faster than the GitHub runners', the parallel scenario may randomly fail. To mitigate it, you can increase values in `n_predict`, `kv_size`.
Note: If the host architecture's inference speed is faster than the GitHub runners', the parallel scenario may randomly fail.
To mitigate it, you can increase values in `n_predict`, `kv_size`.
### Install dependencies
`pip install -r requirements.txt`
### Run tests
1. Build the server
```shell
cd ../../..
mkdir build
@@ -24,24 +32,36 @@ cd build
cmake ../
cmake --build . --target server
```
2. download required models:
1. `../../../scripts/hf.sh --repo ggml-org/models --file tinyllamas/stories260K.gguf`
3. Start the test: `./tests.sh`
2. Start the test: `./tests.sh`
It's possible to override some scenario steps values with environment variables:
- `PORT` -> `context.server_port` to set the listening port of the server during scenario, default: `8080`
- `LLAMA_SERVER_BIN_PATH` -> to change the server binary path, default: `../../../build/bin/server`
- `DEBUG` -> "ON" to enable steps and server verbose mode `--verbose`
- `SERVER_LOG_FORMAT_JSON` -> if set switch server logs to json format
| variable | description |
|--------------------------|------------------------------------------------------------------------------------------------|
| `PORT` | `context.server_port` to set the listening port of the server during scenario, default: `8080` |
| `LLAMA_SERVER_BIN_PATH` | to change the server binary path, default: `../../../build/bin/server` |
| `DEBUG` | "ON" to enable steps and server verbose mode `--verbose` |
| `SERVER_LOG_FORMAT_JSON` | if set switch server logs to json format |
| `N_GPU_LAYERS` | number of model layers to offload to VRAM `-ngl --n-gpu-layers` |
### Run @bug, @wip or @wrong_usage annotated scenario
Feature or Scenario must be annotated with `@llama.cpp` to be included in the default scope.
- `@bug` annotation aims to link a scenario with a GitHub issue.
- `@wrong_usage` scenarios are meant to show user issues that are actually expected behavior
- `@wip` to focus on a scenario that is a work in progress
- `@slow` heavy test, disabled by default
To run a scenario annotated with `@bug`, start:
`DEBUG=ON ./tests.sh --no-skipped --tags bug`
```shell
DEBUG=ON ./tests.sh --no-skipped --tags bug
```
After changing logic in `steps.py`, ensure that the `@bug` and `@wrong_usage` scenarios are updated.
```shell
./tests.sh --no-skipped --tags bug,wrong_usage || echo "should fail but compile"
```

View File

@@ -7,7 +7,10 @@ from signal import SIGKILL
def before_scenario(context, scenario):
print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m")
context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
if context.debug:
print("DEBUG=ON\n")
print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m\n")
port = 8080
if 'PORT' in os.environ:
port = int(os.environ['PORT'])

View File

@@ -1,4 +1,5 @@
# List of ongoing issues
# run with: DEBUG=ON ./tests.sh --no-skipped --tags bug
@bug
Feature: Issues
# No confirmed issue at the moment

View File

@@ -1,11 +1,12 @@
@llama.cpp
@parallel
Feature: Parallel
Background: Server startup
Given a server listening on localhost:8080
And a model file stories260K.gguf
And a model alias tinyllama-2
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And 42 as server seed
And 512 as batch size
And 64 KV cache size
And 2 slots
And embeddings extraction

View File

@@ -0,0 +1,55 @@
# run with: ./tests.sh --no-skipped --tags passkey
@passkey
@slow
Feature: Passkey / Self-extend with context shift
Background: Server startup
Given a server listening on localhost:8080
# Generates a long text of junk and inserts a secret passkey number inside it.
# Then we query the LLM for the secret passkey.
# see #3856 and #4810
Scenario Outline: Passkey
Given a model file <hf_file> from HF repo <hf_repo>
And <n_batch> as batch size
And <n_junk> as number of junk
And <n_predicted> server max tokens to predict
And 42 as seed
And <n_ctx> KV cache size
And 1 slots
And <n_ga> group attention factor to extend context size through self-extend
And <n_ga_w> group attention width to extend context size through self-extend
# Can be overridden with N_GPU_LAYERS
And <ngl> GPU offloaded layers
Then the server is starting
Then the server is healthy
Given available models
Then model 0 is trained on <n_ctx_train> tokens context
Given a prefix prompt:
"""
here is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.
"""
And a passkey prompt template:
"""
The pass key is <passkey> Remember it. <passkey> is the pass key.
"""
And a junk suffix prompt:
"""
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
"""
And a suffix prompt:
"""
What is the pass key? The pass key is
"""
Given a "<passkey>" passkey challenge prompt with the passkey inserted every <i_pos> junk
And a completion request with no api error
Then <n_predicted> tokens are predicted matching <re_content>
Examples:
| hf_repo | hf_file | n_ctx_train | ngl | n_ctx | n_batch | n_ga | n_ga_w | n_junk | i_pos | passkey | n_predicted | re_content |
| TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048 | 5 | 8192 | 512 | 4 | 512 | 250 | 50 | 42 | 1 | 42 |
| TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048 | 5 | 8192 | 512 | 2 | 512 | 250 | 50 | 42 | 1 | \b((?!42)\w)+\b |
#| TheBloke/Llama-2-7B-GGUF | llama-2-7b.Q2_K.gguf | 4096 | 3 | 16384 | 512 | 4 | 512 | 500 | 300 | 1234 | 5 | 1234 |
#| TheBloke/Mixtral-8x7B-v0.1-GGUF | mixtral-8x7b-v0.1.Q2_K.gguf | 32768 | 2 | 16384 | 512 | 4 | 512 | 500 | 100 | 0987 | 5 | 0987 |

View File

@@ -1,9 +1,10 @@
@llama.cpp
@security
Feature: Security
Background: Server startup with an api key defined
Given a server listening on localhost:8080
And a model file stories260K.gguf
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a server api key llama.cpp
Then the server is starting
Then the server is healthy

View File

@@ -1,15 +1,17 @@
@llama.cpp
@server
Feature: llama.cpp server
Background: Server startup
Given a server listening on localhost:8080
And a model file stories260K.gguf
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a model alias tinyllama-2
And 42 as server seed
# KV Cache corresponds to the total amount of tokens
# that can be stored across all independent sequences: #4130
# see --ctx-size and #5568
And 32 KV cache size
And 512 as batch size
And 1 slots
And embeddings extraction
And 32 server max tokens to predict
@@ -29,9 +31,9 @@ Feature: llama.cpp server
And prometheus metrics are exposed
Examples: Prompts
| prompt | n_predict | re_content | n_predicted |
| I believe the meaning of life is | 8 | (read<or>going)+ | 8 |
| Write a joke about AI | 64 | (park<or>friends<or>scared<or>always)+ | 32 |
| prompt | n_predict | re_content | n_predicted |
| I believe the meaning of life is | 8 | (read\|going)+ | 8 |
| Write a joke about AI | 64 | (park\|friends\|scared\|always)+ | 32 |
Scenario Outline: OAI Compatibility
Given a model <model>
@@ -43,9 +45,9 @@ Feature: llama.cpp server
Then <n_predicted> tokens are predicted matching <re_content>
Examples: Prompts
| model | system_prompt | user_prompt | max_tokens | re_content | n_predicted | enable_streaming |
| llama-2 | Book | What is the best book | 8 | (Mom<or>what)+ | 8 | disabled |
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks<or>happy<or>bird)+ | 32 | enabled |
| model | system_prompt | user_prompt | max_tokens | re_content | n_predicted | enable_streaming |
| llama-2 | Book | What is the best book | 8 | (Mom\|what)+ | 8 | disabled |
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks\|happy\|bird)+ | 32 | enabled |
Scenario: Embedding
When embeddings are computed for:
@@ -75,10 +77,15 @@ Feature: llama.cpp server
When an OAI compatible embeddings computation request for multiple inputs
Then embeddings are generated
Scenario: Tokenize / Detokenize
When tokenizing:
"""
What is the capital of France ?
"""
Then tokens can be detokenize
Scenario: Models available
Given available models
Then 1 models are supported
Then model 0 is identified by tinyllama-2
Then model 0 is trained on 128 tokens context

View File

@@ -13,6 +13,7 @@ import aiohttp
import openai
from behave import step
from behave.api.async_step import async_run_until_complete
from huggingface_hub import hf_hub_download
from prometheus_client import parser
@@ -26,17 +27,23 @@ def step_server_config(context, server_fqdn, server_port):
context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
context.model_alias = None
context.n_batch = None
context.n_ctx = None
context.n_ga = None
context.n_ga_w = None
context.n_gpu_layer = None
context.n_predict = None
context.n_server_predict = None
context.n_slots = None
context.prompt_prefix = None
context.prompt_suffix = None
context.server_api_key = None
context.server_continuous_batching = False
context.server_embeddings = False
context.server_metrics = False
context.server_process = None
context.seed = None
context.server_seed = None
context.user_api_key = None
@@ -45,9 +52,11 @@ def step_server_config(context, server_fqdn, server_port):
context.prompts = []
@step(u'a model file {model_file}')
def step_model_file(context, model_file):
context.model_file = model_file
@step(u'a model file {hf_file} from HF repo {hf_repo}')
def step_download_hf_model(context, hf_file, hf_repo):
context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
if context.debug:
print(f"model file: {context.model_file}\n")
@step(u'a model alias {model_alias}')
@@ -55,24 +64,34 @@ def step_model_alias(context, model_alias):
context.model_alias = model_alias
@step(u'{seed} as server seed')
@step(u'{seed:d} as server seed')
def step_seed(context, seed):
context.server_seed = int(seed)
context.server_seed = seed
@step(u'{n_ctx} KV cache size')
@step(u'{ngl:d} GPU offloaded layers')
def step_n_gpu_layer(context, ngl):
if 'N_GPU_LAYERS' in os.environ:
new_ngl = int(os.environ['N_GPU_LAYERS'])
if context.debug:
print(f"-ngl upgraded from {ngl} to {new_ngl}")
ngl = new_ngl
context.n_gpu_layer = ngl
@step(u'{n_ctx:d} KV cache size')
def step_n_ctx(context, n_ctx):
context.n_ctx = int(n_ctx)
context.n_ctx = n_ctx
@step(u'{n_slots} slots')
@step(u'{n_slots:d} slots')
def step_n_slots(context, n_slots):
context.n_slots = int(n_slots)
context.n_slots = n_slots
@step(u'{n_predict} server max tokens to predict')
@step(u'{n_predict:d} server max tokens to predict')
def step_server_n_predict(context, n_predict):
context.n_server_predict = int(n_predict)
context.n_server_predict = n_predict
@step(u'continuous batching')
@@ -116,11 +135,13 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):
case 'ready' | 'idle':
await wait_for_health_status(context, context.base_url, 200, 'ok',
timeout=10,
params={'fail_on_no_slot': 0, 'include_slots': 0},
slots_idle=context.n_slots,
slots_processing=0,
expected_slots=[{'id': slot_id, 'state': 0}
for slot_id in range(context.n_slots)])
for slot_id in
range(context.n_slots if context.n_slots else 1)])
case 'busy':
await wait_for_health_status(context, context.base_url, 503,
'no slot available',
@@ -128,7 +149,8 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):
slots_idle=0,
slots_processing=context.n_slots,
expected_slots=[{'id': slot_id, 'state': 1}
for slot_id in range(context.n_slots)])
for slot_id in
range(context.n_slots if context.n_slots else 1)])
case _:
assert False, "unknown status"
@@ -157,24 +179,24 @@ async def step_request_completion(context, api_error):
context.base_url,
debug=context.debug,
n_predict=context.n_predict,
server_seed=context.server_seed,
seed=await completions_seed(context),
expect_api_error=expect_api_error,
user_api_key=context.user_api_key)
context.tasks_result.append(completion)
if context.debug:
print(f"Completion response: {completion}")
print(f"Completion response: {completion}\n")
if expect_api_error:
assert completion == 401, f"completion must be an 401 status code: {completion}"
@step(u'{predicted_n} tokens are predicted matching {re_content}')
@step(u'{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content)
assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)
@step(u'{predicted_n} tokens are predicted')
@step(u'{predicted_n:d} tokens are predicted')
def step_n_tokens_predicted(context, predicted_n):
assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n))
assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)
@step(u'a user prompt {user_prompt}')
@@ -192,9 +214,9 @@ def step_model(context, model):
context.model = model
@step(u'{max_tokens} max tokens to predict')
@step(u'{max_tokens:d} max tokens to predict')
def step_max_tokens(context, max_tokens):
context.n_predict = int(max_tokens)
context.n_predict = max_tokens
@step(u'streaming is {enable_streaming}')
@@ -222,11 +244,70 @@ def step_server_api_key(context, server_api_key):
context.server_api_key = server_api_key
@step(u'{n_junk:d} as number of junk')
def step_n_junk(context, n_junk):
context.n_junk = n_junk
@step(u'{n_batch:d} as batch size')
def step_n_batch(context, n_batch):
context.n_batch = n_batch
@step(u'{seed:d} as seed')
def step_seed(context, seed):
context.seed = seed
@step(u'a prefix prompt')
def step_prompt_prefix(context):
context.prompt_prefix = context.text
@step(u'a junk suffix prompt')
def step_prompt_junk_suffix(context):
context.prompt_junk_suffix = context.text
@step(u'a suffix prompt')
def step_prompt_suffix(context):
context.prompt_suffix = context.text
@step(u'{n_ga:d} group attention factor'
u' to extend context size through self-extend')
def step_impl(context, n_ga):
context.n_ga = n_ga
@step(u'{n_ga_w:d} group attention width to extend context size through self-extend')
def step_impl(context, n_ga_w):
context.n_ga_w = n_ga_w
@step(u'a passkey prompt template')
def step_prompt_passkey(context):
context.prompt_passkey = context.text
@step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk')
def step_prompt_passkey(context, passkey, i_pos):
prompt = ""
for i in range(context.n_junk):
if i % context.n_junk == i_pos:
prompt += context.prompt_passkey # the passkey is already substituted
prompt += context.prompt_junk_suffix
if context.debug:
passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
@step(u'an OAI compatible chat completions request with {api_error} api error')
@async_run_until_complete
async def step_oai_chat_completions(context, api_error):
if context.debug:
print(f"Submitting OAI compatible completions request...")
print(f"Submitting OAI compatible completions request...\n")
expect_api_error = api_error == 'raised'
completion = await oai_chat_completions(context.prompts.pop(),
context.system_prompt,
@@ -241,8 +322,7 @@ async def step_oai_chat_completions(context, api_error):
enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None,
server_seed=context.server_seed
if hasattr(context, 'server_seed') else None,
seed=await completions_seed(context),
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None,
@@ -276,8 +356,10 @@ async def step_concurrent_completion_requests(context):
# prompt is inserted automatically
context.base_url,
debug=context.debug,
prompt_prefix=context.prompt_prefix,
prompt_suffix=context.prompt_suffix,
n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
seed=await completions_seed(context),
user_api_key=context.user_api_key if hasattr(context,
'user_api_key') else None)
@@ -297,8 +379,7 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'n_predict') else None,
enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None,
server_seed=context.server_seed
if hasattr(context, 'server_seed') else None,
seed=await completions_seed(context),
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None)
@@ -318,7 +399,9 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'n_predict') else None,
enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None,
server_seed=context.server_seed
seed=context.seed
if hasattr(context, 'seed') else
context.server_seed
if hasattr(context, 'server_seed') else None,
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None)
@@ -330,11 +413,10 @@ async def step_all_prompts_are_predicted(context):
await all_prompts_are_predicted(context)
@step(u'all prompts are predicted with {n_predict} tokens')
@step(u'all prompts are predicted with {n_expected_predicted:d} tokens')
@async_run_until_complete
async def step_all_prompts_are_predicted_with_n_tokens(context, n_predict):
expected_predicted_n = int(n_predict)
await all_prompts_are_predicted(context, expected_predicted_n)
async def step_all_prompts_are_predicted_with_n_tokens(context, n_expected_predicted):
await all_prompts_are_predicted(context, n_expected_predicted)
async def all_prompts_are_predicted(context, expected_predicted_n=None):
@@ -464,6 +546,8 @@ async def step_prometheus_metrics_exported(context):
assert metrics_response.headers['Content-Type'] == "text/plain; version=0.0.4"
metrics_raw = await metrics_response.text()
metric_exported = False
if context.debug:
print(f"/metrics answer:\n{metrics_raw}\n")
for metric in parser.text_string_to_metric_families(metrics_raw):
match metric.name:
case "llamacpp:kv_cache_usage_ratio":
@@ -472,6 +556,37 @@ async def step_prometheus_metrics_exported(context):
assert metric_exported, "No metrics exported"
@step(u'available models')
def step_available_models(context):
# openai client always expects an api_key
openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope'
openai.api_base = f'{context.base_url}/v1'
context.models = openai.Model.list().data
@step(u'{n_model:d} models are supported')
def step_supported_models(context, n_model):
if context.debug:
print("server models available:", context.models)
assert len(context.models) == n_model
@step(u'model {i_model:d} is {param} {preposition} {param_value}')
def step_supported_models(context, i_model, param, preposition, param_value):
assert i_model < len(context.models)
model = context.models[i_model]
param_value = param_value.split(' ', 1)[0]
match param:
case 'identified':
value = model.id
case 'trained':
value = str(model.meta.n_ctx_train)
case _:
assert False, "param {param} not supported"
assert param_value == value, f"model param {param} {value} != {param_value}"
async def concurrent_requests(context, f_completion, *args, **kwargs):
n_prompts = len(context.prompts)
if context.debug:
@@ -486,8 +601,10 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
async def request_completion(prompt,
base_url,
debug=False,
prompt_prefix=None,
prompt_suffix=None,
n_predict=None,
server_seed=None,
seed=None,
expect_api_error=None,
user_api_key=None):
if debug:
@@ -504,11 +621,14 @@ async def request_completion(prompt,
async with aiohttp.ClientSession() as session:
async with session.post(f'{base_url}/completion',
json={
"input_prefix": prompt_prefix,
"prompt": prompt,
"n_predict": int(n_predict) if n_predict is not None else -1,
"seed": server_seed if server_seed is not None else 42
"input_suffix": prompt_suffix,
"n_predict": n_predict if n_predict is not None else -1,
"seed": seed if seed is not None else 42
},
headers=headers) as response:
headers=headers,
timeout=3600) as response:
if expect_api_error is None or not expect_api_error:
assert response.status == 200
assert response.headers['Access-Control-Allow-Origin'] == origin
@@ -526,14 +646,14 @@ async def oai_chat_completions(user_prompt,
model=None,
n_predict=None,
enable_streaming=None,
server_seed=None,
seed=None,
user_api_key=None,
expect_api_error=None):
if debug:
print(f"Sending OAI Chat completions request: {user_prompt}")
# openai client always expects an api key
user_api_key = user_api_key if user_api_key is not None else 'nope'
seed = server_seed if server_seed is not None else 42
seed = seed if seed is not None else 42
enable_streaming = enable_streaming if enable_streaming is not None else False
payload = {
"messages": [
@@ -692,20 +812,32 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
content = completion_response['content']
n_predicted = completion_response['timings']['predicted_n']
assert len(content) > 0, "no token predicted"
if expected_predicted_n is not None:
if re_content is not None:
p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL)
matches = p.finditer(content)
last_match = 0
highlighted = ''
for match in matches:
start, end = match.span()
highlighted += content[last_match: start]
highlighted += '\x1b[33m'
highlighted += content[start: end]
highlighted += '\x1b[0m'
last_match = end
highlighted += content[last_match:]
if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
print(f"Checking completion response: {highlighted}\n")
assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
if expected_predicted_n and expected_predicted_n > 0:
assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
f' {n_predicted} <> {expected_predicted_n}')
if re_content is not None:
re_content = '^.*' + re_content.replace('<or>', '|') + '.*$'
assert re.match(re_content, content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL), (
f'invalid tokens predicted:'
f' ```\n{content}\n``` do not match /{re_content}/')
async def gather_tasks_results(context):
n_tasks = len(context.concurrent_tasks)
if context.debug:
print(f"Waiting for all {n_tasks} tasks results...")
print(f"Waiting for all {n_tasks} tasks results...\n")
for task_no in range(n_tasks):
context.tasks_result.append(await context.concurrent_tasks.pop())
n_completions = len(context.tasks_result)
@@ -716,15 +848,13 @@ async def wait_for_health_status(context,
base_url,
expected_http_status_code,
expected_health_status,
timeout=3,
params=None,
slots_idle=None,
slots_processing=None,
expected_slots=None):
if context.debug:
print(f"Starting checking for health for expected_health_status={expected_health_status}")
timeout = 3 # seconds
if expected_health_status == 'ok':
timeout = 10 # CI slow inference
print(f"Starting checking for health for expected_health_status={expected_health_status}\n")
interval = 0.5
counter = 0
async with aiohttp.ClientSession() as session:
@@ -734,7 +864,7 @@ async def wait_for_health_status(context,
health = await health_response.json()
if context.debug:
print(f"HEALTH - response for expected health status='{expected_health_status}' on "
f"'{base_url}/health'?{params} is {health}")
f"'{base_url}/health'?{params} is {health}\n")
if (status_code == expected_http_status_code
and health['status'] == expected_health_status
and (slots_idle is None or health['slots_idle'] == slots_idle)
@@ -757,7 +887,7 @@ async def wait_for_health_status(context,
if expected_http_status_code == 503:
if len(context.tasks_result) == 0:
print("\x1b[5;37;43mWARNING: forcing concurrent tasks,"
" busy health check missed, probably too fast inference\x1b[0m")
" busy health check missed, probably too fast inference\x1b[0m\n")
n_completions = await gather_tasks_results(context)
if n_completions > 0:
return
@@ -791,6 +921,11 @@ def assert_slots_status(slots, expected_slots):
f" = {expected[key]} != {slot[key]}")
async def completions_seed(context):
return context.seed if hasattr(context, 'seed') and context.seed is not None \
else context.server_seed if hasattr(context, 'server_seed') else None
def start_server_background(context):
context.server_path = '../../../build/bin/server'
if 'LLAMA_SERVER_BIN_PATH' in os.environ:
@@ -800,27 +935,35 @@ def start_server_background(context):
'--port', context.server_port,
'--model', context.model_file
]
if context.n_batch:
server_args.extend(['--batch-size', context.n_batch])
if context.n_gpu_layer:
server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
if context.server_continuous_batching:
server_args.append('--cont-batching')
if context.server_embeddings:
server_args.append('--embedding')
if context.server_metrics:
server_args.append('--metrics')
if context.model_alias is not None:
if context.model_alias:
server_args.extend(['--alias', context.model_alias])
if context.n_ctx is not None:
if context.n_ctx:
server_args.extend(['--ctx-size', context.n_ctx])
if context.n_slots is not None:
if context.n_slots:
server_args.extend(['--parallel', context.n_slots])
if context.n_server_predict is not None:
if context.n_server_predict:
server_args.extend(['--n-predict', context.n_server_predict])
if context.server_api_key is not None:
if context.server_api_key:
server_args.extend(['--api-key', context.server_api_key])
if context.n_ga:
server_args.extend(['--grp-attn-n', context.n_ga])
if context.n_ga_w:
server_args.extend(['--grp-attn-w', context.n_ga_w])
if context.debug:
server_args.append('--verbose')
if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
server_args.extend(['--log-format', "text"])
print(f"starting server with: {context.server_path}", *server_args)
print(f"starting server with: {context.server_path} {server_args}\n")
context.server_process = subprocess.Popen(
[str(arg) for arg in [context.server_path, *server_args]],
close_fds=True)

View File

@@ -1,4 +1,4 @@
# run with ./test.sh --tags wrong_usage
# run with: ./tests.sh --no-skipped --tags wrong_usage
@wrong_usage
Feature: Wrong usage of llama.cpp server
@@ -7,7 +7,7 @@ Feature: Wrong usage of llama.cpp server
# or pass n_predict/max_tokens in the request.
Scenario: Infinite loop
Given a server listening on localhost:8080
And a model file stories260K.gguf
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
# Uncomment below to fix the issue
#And 64 server max tokens to predict
Then the server is starting
@@ -18,4 +18,5 @@ Feature: Wrong usage of llama.cpp server
# Uncomment below to fix the issue
#And 128 max tokens to predict
Given concurrent completion requests
Then the server is idle
Then all prompts are predicted

View File

@@ -1,4 +1,5 @@
aiohttp~=3.9.3
behave~=1.2.6
huggingface_hub~=0.20.3
openai~=0.25.0
prometheus-client~=0.20.0

View File

@@ -5,7 +5,7 @@ set -eu
if [ $# -lt 1 ]
then
# Start @llama.cpp scenario
behave --summary --stop --no-capture --exclude 'issues|wrong_usages' --tags llama.cpp
behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp
else
behave "$@"
fi

View File

@@ -126,8 +126,7 @@ static inline void server_log(const char *level, const char *function, int line,
for (const auto& el : log.items())
{
const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
snprintf(buf, 1024, " %s=%s", el.key().c_str(), value.c_str());
ss << buf;
ss << " " << el.key() << "=" << value;
}
const std::string str = ss.str();

View File

@@ -6,3 +6,4 @@ More info:
- https://github.com/ggerganov/llama.cpp/pull/2926
- https://github.com/ggerganov/llama.cpp/pull/3624
- https://github.com/ggerganov/llama.cpp/pull/5625

View File

@@ -5,6 +5,7 @@
#include <cstdio>
#include <string>
#include <vector>
#include <set>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@@ -18,6 +19,7 @@ struct seq_draft {
std::vector<int> i_batch_tgt;
std::vector<llama_token> tokens;
std::vector<std::vector<llama_token_data>> dists;
struct llama_sampling_context * ctx_sampling;
};
@@ -37,12 +39,15 @@ int main(int argc, char ** argv) {
// max number of parallel drafting sequences (i.e. tree branches)
const int n_seq_dft = params.n_parallel;
// probability threshold for accepting a token from the draft model
const float p_accept = params.p_accept;
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
const float p_split = params.p_split;
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}
std::default_random_engine rng(params.seed);
std::uniform_real_distribution<> u_dist;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("speculative", "log"));
LOG_TEE("Log start\n");
@@ -166,7 +171,9 @@ int main(int argc, char ** argv) {
std::vector<seq_draft> drafts(n_seq_dft);
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
if (params.sparams.temp == 0) {
params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
}
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
@@ -182,12 +189,15 @@ int main(int argc, char ** argv) {
drafts[0].i_batch_tgt[0] = 0;
while (true) {
std::set<int> active_seqs = {};
// print current draft sequences
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) {
continue;
}
active_seqs.insert(s);
const auto & tokens = drafts[s].tokens;
LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
@@ -196,48 +206,156 @@ int main(int argc, char ** argv) {
int i_dft = 0;
int s_keep = 0;
llama_token token_id;
std::string token_str;
// loop until we fail to accept a drafted token or we run out of drafted tokens
while (true) {
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
// sample from the target model
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
const std::string token_str = llama_token_to_piece(ctx_tgt, id);
if (!params.use_color) {
printf("%s", token_str.c_str());
}
if (id == llama_token_eos(model_tgt)) {
has_eos = true;
}
++n_predict;
// check if the target token matches any of the drafts
// for stochastic sampling, attempt to match the token with the drafted tokens
{
bool matches = false;
bool accept = false;
if (params.sparams.temp > 0) {
// stochastic verification
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) {
continue;
llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
float p_tgt = 0, p_dft = 0;
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
while (active_seqs.size() > 0) {
// randomly select a sequence to verify from active sequences
std::uniform_int_distribution<unsigned int> u_int_dist(0, active_seqs.size() - 1);
int s = *std::next(active_seqs.begin(), u_int_dist(rng));
if (i_dft >= (int) drafts[s].tokens.size()) {
drafts[s].active = false;
active_seqs.erase(s);
continue;
}
if (accept) {
// if we already accepted a token, we can skip the rest
if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
drafts[s].active = false;
active_seqs.erase(s);
}
continue;
}
LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
float r = u_dist(rng);
llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
// acquire the token probabilities assigned by the draft and target models
for (size_t i = 0; i < dist_tgt.size; i++) {
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
p_tgt = dist_tgt.data[i].p;
}
if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
p_dft = dist_dft.data[i].p;
}
if (p_tgt && p_dft) {
break;
}
}
LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
if (r <= p_tgt / p_dft) {
s_keep = s;
accept = true;
token_id = drafts[s].tokens[i_dft];
token_str = llama_token_to_piece(ctx_tgt, token_id);
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
break;
} else {
LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
drafts[s].active = false;
// calculate residual probability
GGML_ASSERT(dist_tgt.sorted);
GGML_ASSERT(dist_dft.sorted);
float sum_probs = 0.0f;
// sort dist by id
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
return a.id < b.id;
});
std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) {
return a.id < b.id;
});
for (size_t i = 0; i < dist_tgt.size; i++) {
dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
sum_probs += dist_tgt.data[i].p;
}
for (size_t i = 0; i < dist_tgt.size; i++) {
dist_tgt.data[i].p /= sum_probs;
}
// sort dist_tgt by p desc
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
return a.p > b.p;
});
}
active_seqs.erase(s);
for(int i = 0; i < n_seq_dft; i++) {
if (i == s) {
continue;
}
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
// synchronize active status for sequences with the same drafted token
drafts[i].active = drafts[i].active && accept;
if (!drafts[i].active) {
active_seqs.erase(s);
}
}
}
}
if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
if (!accept) {
// all drafted tokens were rejected
// sample from the target model
LOG("all drafted tokens were rejected, sampling from residual distribution\n");
token_id = llama_sample_token(ctx_tgt, &dist_tgt);
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
token_str = llama_token_to_piece(ctx_tgt, token_id);
}
s_keep = s;
matches = true;
} else {
drafts[s].active = false;
} else {
// greedy verification
// sample from the target model
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
token_str = llama_token_to_piece(ctx_tgt, token_id);
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) {
continue;
}
if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
s_keep = s;
accept = true;
} else {
drafts[s].active = false;
}
}
}
if (matches) {
if (token_id == llama_token_eos(model_tgt)) {
has_eos = true;
}
++n_predict;
if (accept) {
++n_accept;
++n_past_tgt;
++n_past_dft;
@@ -245,17 +363,21 @@ int main(int argc, char ** argv) {
if (params.use_color) {
// Color token according to its origin sequence
printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
fflush(stdout);
} else {
printf("%s", token_str.c_str());
}
fflush(stdout);
continue;
} else {
printf("%s", token_str.c_str());
fflush(stdout);
break;
}
}
if (params.use_color) {
printf("%s", token_str.c_str());
}
fflush(stdout);
}
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
{
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
// TODO: simplify
{
@@ -275,21 +397,21 @@ int main(int argc, char ** argv) {
drafts[s].active = false;
drafts[s].tokens.clear();
drafts[s].i_batch_tgt.clear();
drafts[s].dists.clear();
}
// note: will be erased after the speculation phase
drafts[0].tokens.push_back(id);
drafts[0].tokens.push_back(token_id);
drafts[0].dists.push_back(std::vector<llama_token_data>());
drafts[0].i_batch_tgt.push_back(0);
llama_batch_clear(batch_dft);
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode (ctx_dft, batch_dft);
llama_decode(ctx_dft, batch_dft);
++n_past_dft;
break;
}
if (n_predict > params.n_predict || has_eos) {
@@ -334,12 +456,6 @@ int main(int argc, char ** argv) {
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
}
if (cur_p[0].p < p_accept) {
LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept);
drafts[s].drafting = false;
continue;
}
std::vector<int> sa(1, s);
// attempt to split the branch if the probability is high enough
@@ -367,6 +483,7 @@ int main(int argc, char ** argv) {
drafts[n_seq_cur].skip = true;
drafts[n_seq_cur].tokens = drafts[s].tokens;
drafts[n_seq_cur].dists = drafts[s].dists;
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
@@ -389,6 +506,8 @@ int main(int argc, char ** argv) {
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
drafts[s].tokens.push_back(id);
// save cur_p.data into drafts[s].dists
drafts[s].dists.push_back(cur_p);
// add unique drafted tokens to the target batch
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
@@ -440,6 +559,7 @@ int main(int argc, char ** argv) {
}
drafts[s].tokens.erase(drafts[s].tokens.begin());
drafts[s].dists.erase(drafts[s].dists.begin());
}
}
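For context on the stochastic verification added above: the acceptance test is the standard speculative-sampling criterion. A token drafted with probability p_dft is kept with probability min(1, p_tgt / p_dft); on rejection, the target distribution is reduced to the residual max(0, p_tgt - p_dft), renormalized, and the replacement token is sampled from that residual. A minimal sketch of the rule in isolation (not the llama.cpp API; names are hypothetical):

#include <algorithm>
#include <random>
#include <vector>

// Returns true if the drafted token is accepted. On rejection, dist_tgt is
// rewritten in place into the renormalized residual max(0, p_tgt - p_dft),
// from which the replacement token should then be sampled.
static bool accept_draft(size_t drafted_id,
                         std::vector<float> & dist_tgt,        // target probabilities, indexed by token id
                         const std::vector<float> & dist_dft,  // draft probabilities, same indexing
                         std::mt19937 & rng) {
    std::uniform_real_distribution<float> u(0.0f, 1.0f);
    const float p_tgt = dist_tgt[drafted_id];
    const float p_dft = dist_dft[drafted_id];

    if (u(rng) <= p_tgt / p_dft) {
        return true; // keep the drafted token
    }

    // build the residual distribution (assumes it is non-degenerate)
    float sum = 0.0f;
    for (size_t i = 0; i < dist_tgt.size(); ++i) {
        dist_tgt[i] = std::max(0.0f, dist_tgt[i] - dist_dft[i]);
        sum += dist_tgt[i];
    }
    for (float & p : dist_tgt) {
        p /= sum; // renormalize
    }
    return false;
}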

flake.lock generated (18 changed lines)
flake.lock generated (18 changed lines)
View File

@@ -5,11 +5,11 @@
"nixpkgs-lib": "nixpkgs-lib"
},
"locked": {
"lastModified": 1706830856,
"narHash": "sha256-a0NYyp+h9hlb7ddVz4LUn1vT/PLwqfrWYcHMvFB1xYg=",
"lastModified": 1709336216,
"narHash": "sha256-Dt/wOWeW6Sqm11Yh+2+t0dfEWxoMxGBvv3JpIocFl9E=",
"owner": "hercules-ci",
"repo": "flake-parts",
"rev": "b253292d9c0a5ead9bc98c4e9a26c6312e27d69f",
"rev": "f7b3c975cf067e56e7cda6cb098ebe3fb4d74ca2",
"type": "github"
},
"original": {
@@ -20,11 +20,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1708655239,
"narHash": "sha256-ZrP/yACUvDB+zbqYJsln4iwotbH6CTZiTkANJ0AgDv4=",
"lastModified": 1709237383,
"narHash": "sha256-cy6ArO4k5qTx+l5o+0mL9f5fa86tYUX3ozE1S+Txlds=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "cbc4211f0afffe6dfd2478a62615dd5175a13f9a",
"rev": "1536926ef5621b09bba54035ae2bb6d806d72ac8",
"type": "github"
},
"original": {
@@ -37,11 +37,11 @@
"nixpkgs-lib": {
"locked": {
"dir": "lib",
"lastModified": 1706550542,
"narHash": "sha256-UcsnCG6wx++23yeER4Hg18CXWbgNpqNXcHIo5/1Y+hc=",
"lastModified": 1709237383,
"narHash": "sha256-cy6ArO4k5qTx+l5o+0mL9f5fa86tYUX3ozE1S+Txlds=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "97b17f32362e475016f942bbdfda4a4a72a8a652",
"rev": "1536926ef5621b09bba54035ae2bb6d806d72ac8",
"type": "github"
},
"original": {

View File

@@ -91,13 +91,14 @@ extern "C" {
// (optional) complete all pending operations
void (*GGML_CALL synchronize)(ggml_backend_t backend);
// compute graph with a plan
// create a plan for ggml_cgraph and free it
ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
void (*GGML_CALL graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
void (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
// compute graph with a plan
enum ggml_status (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
// compute graph without a plan (async)
bool (*GGML_CALL graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
enum ggml_status (*GGML_CALL graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
// check if the backend supports an operation
bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
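With graph_plan_compute and graph_compute now returning enum ggml_status instead of void/bool, a backend implementation maps its own failure modes onto the status codes. A hypothetical backend's graph_compute under the new interface could look like this (sketch only; the my_backend_* names are made up):

GGML_CALL static enum ggml_status my_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
    struct my_backend_context * ctx = (struct my_backend_context *) backend->context; // hypothetical context type

    for (int i = 0; i < cgraph->n_nodes; ++i) {
        if (!my_backend_compute_node(ctx, cgraph->nodes[i])) { // hypothetical per-node helper
            return GGML_STATUS_FAILED;                         // previously `return false;`
        }
    }
    return GGML_STATUS_SUCCESS;                                // previously `return true;`
}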

View File

@@ -262,11 +262,11 @@ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_pla
backend->iface.graph_plan_free(backend, plan);
}
void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
backend->iface.graph_plan_compute(backend, plan);
enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
return backend->iface.graph_plan_compute(backend, plan);
}
bool ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
return backend->iface.graph_compute(backend, cgraph);
}
@@ -732,15 +732,15 @@ GGML_CALL static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, g
GGML_UNUSED(backend);
}
GGML_CALL static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
GGML_CALL static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
GGML_UNUSED(backend);
}
GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
@@ -755,8 +755,7 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
cplan.abort_callback = cpu_ctx->abort_callback;
cplan.abort_callback_data = cpu_ctx->abort_callback_data;
ggml_graph_compute(cgraph, &cplan);
return true;
return ggml_graph_compute(cgraph, &cplan);
}
GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
@@ -1437,7 +1436,7 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
return true;
}
static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
uint64_t copy_us[GGML_MAX_BACKENDS] = {0};
uint64_t compute_us[GGML_MAX_BACKENDS] = {0};
@@ -1472,8 +1471,9 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
uint64_t compute_start_us = ggml_time_us();
if (!sched->callback_eval) {
if (!ggml_backend_graph_compute(split_backend, &split->graph)) {
return false;
enum ggml_status ec = ggml_backend_graph_compute(split_backend, &split->graph);
if (ec != GGML_STATUS_SUCCESS) {
return ec;
}
//ggml_backend_synchronize(split_backend); // necessary to measure compute time
} else {
@@ -1494,8 +1494,9 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
if (!ggml_backend_graph_compute(split_backend, &gv)) {
return false;
enum ggml_status ec = ggml_backend_graph_compute(split_backend, &gv);
if (ec != GGML_STATUS_SUCCESS) {
return ec;
}
if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
@@ -1519,7 +1520,7 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
}
#endif
return true;
return GGML_STATUS_SUCCESS;
}
ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size) {
@@ -1581,7 +1582,7 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
return true;
}
bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
if (!sched->is_reset) {
@@ -1590,14 +1591,10 @@ bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cg
ggml_backend_sched_split_graph(sched, graph);
if (!ggml_backend_sched_alloc_splits(sched)) {
return false;
return GGML_STATUS_ALLOC_FAILED;
}
if (!ggml_backend_sched_compute_splits(sched)) {
return false;
}
return true;
return ggml_backend_sched_compute_splits(sched);
}
void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {

View File

@@ -66,12 +66,13 @@ extern "C" {
GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API bool ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op);
GGML_API enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
// tensor copy between different backends
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
@@ -157,26 +158,26 @@ extern "C" {
typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
// Initialize a backend scheduler
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
// Initialize backend buffers from a measure graph
GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
// Get the number of splits of the last graph
GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
GGML_API ggml_backend_t ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
GGML_API ggml_backend_t ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
// Allocate and compute graph on the backend scheduler
GGML_API bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
// Reset all assignments and allocators - must be called before changing the node backends
GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
// Set a callback to be called for each resulting node during graph compute
GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
//
// Utils
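On the caller side, the return value of ggml_backend_graph_compute / ggml_backend_sched_graph_compute should now be compared against the status enum rather than treated as a bool. A minimal call-site sketch, assuming only the status values visible in this diff and the usual <stdio.h> / ggml-backend.h includes:

static bool compute_or_report(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
    const enum ggml_status status = ggml_backend_sched_graph_compute(sched, graph);
    if (status != GGML_STATUS_SUCCESS) {
        // GGML_STATUS_ALLOC_FAILED is returned when allocating the splits fails;
        // other failures propagate from the backend's graph_compute
        fprintf(stderr, "%s: graph compute failed with status %d\n", __func__, (int) status);
        return false;
    }
    return true;
}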

View File

@@ -616,6 +616,8 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + Q
#define CUDA_UPSCALE_BLOCK_SIZE 256
#define CUDA_CONCAT_BLOCK_SIZE 256
#define CUDA_PAD_BLOCK_SIZE 256
#define CUDA_ARANGE_BLOCK_SIZE 256
#define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
#define CUDA_ACC_BLOCK_SIZE 256
#define CUDA_IM2COL_BLOCK_SIZE 256
#define CUDA_POOL2D_BLOCK_SIZE 256
@@ -990,17 +992,21 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst,
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
dst[offset_dst] = x[offset_src];
dst[offset_dst] = x[offset_src];
} else {
int offset_src =
nidx +
blockIdx.y * ne0 +
(blockIdx.z - ne02) * ne0 * gridDim.y;
dst[offset_dst] = y[offset_src];
dst[offset_dst] = y[offset_src];
}
}
static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int nb02, const int scale_factor) {
static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int ne00xne01, const int scale_factor) {
// blockIdx.z: idx of ne02*ne03
// blockIdx.y: idx of ne01*scale_factor aka ne1
// blockIDx.x: idx of ne00*scale_factor / BLOCK_SIZE
// ne00xne01: ne00 * ne01
int ne0 = ne00 * scale_factor;
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
@@ -1012,7 +1018,7 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00,
int offset_src =
i00 +
i01 * ne00 +
blockIdx.z * nb02;
blockIdx.z * ne00xne01;
int offset_dst =
nidx +
blockIdx.y * ne0 +
@@ -1020,7 +1026,10 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00,
dst[offset_dst] = x[offset_src];
}
static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02) {
static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
// blockIdx.z: idx of ne2*ne3, aka ne02*ne03
// blockIdx.y: idx of ne1
// blockIDx.x: idx of ne0 / BLOCK_SIZE
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
@@ -1031,19 +1040,53 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
int offset_src =
nidx +
blockIdx.y * ne00 +
blockIdx.z * ne00 * ne01;
dst[offset_dst] = x[offset_src];
dst[offset_dst] = x[offset_src];
} else {
dst[offset_dst] = 0.0f;
}
}
static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
// blockIDx.x: idx of ne0 / BLOCK_SIZE
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
dst[nidx] = start + step * nidx;
}
static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {
// blockIDx.y: idx of timesteps->ne[0]
// blockIDx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE
int i = blockIdx.y;
int j = threadIdx.x + blockIdx.x * blockDim.x;
float * embed_data = (float *)((char *)dst + i*nb1);
if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
embed_data[dim] = 0.f;
}
int half = dim / 2;
if (j >= half) {
return;
}
float timestep = timesteps[i];
float freq = (float)expf(-logf(max_period) * j / half);
float arg = timestep * freq;
embed_data[j] = cosf(arg);
embed_data[j + half] = sinf(arg);
}
template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
// blockIdx.x: num_groups idx
// threadIdx.x: block_size idx
int start = blockIdx.x * group_size;
int end = start + group_size;
@@ -2018,74 +2061,73 @@ static const __device__ uint32_t iq3xxs_grid[256] = {
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
};
static const __device__ uint32_t iq3xs_grid[512] = {
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
static const __device__ uint32_t iq3s_grid[512] = {
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
};
static const __device__ uint64_t iq1s_grid[512] = {
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
@@ -2392,9 +2434,9 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * qs = x[i].qs + 8*ib;
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
const uint8_t signs = x[i].signs[4*ib + il];
for (int j = 0; j < 4; ++j) {
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
@@ -5211,8 +5253,8 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
const int8_t * q8 = bq8_1[ib32].qs;
int sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
@@ -5221,7 +5263,7 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
q8 += 8;
}
const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f;
const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds);
return d * sumi;
#else
assert(false);
@@ -6449,7 +6491,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
@@ -6457,17 +6499,17 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
// then combine those indices with the corresponding byte offsets to get the total offsets
const int i03 = i/(ne00 * ne01 * ne02);
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int64_t i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int i13 = i/(ne10 * ne11 * ne12);
const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
const int64_t i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
cpy_1(cx + x_offset, cdst + dst_offset);
}
@@ -6905,6 +6947,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__syncthreads();
if (warp_id == 0) {
buf_iw[lane_id] = 0.0f;
}
@@ -6956,23 +6999,23 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
template <typename T>
static __global__ void im2col_kernel(
const float * x, T * dst, int batch_offset,
int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
const float * x, T * dst, int64_t batch_offset,
int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int i = threadIdx.x + blockIdx.x * blockDim.x;
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= pelements) {
return;
}
const int ksize = OW * (KH > 1 ? KW : 1);
const int kx = i / ksize;
const int kd = kx * ksize;
const int ky = (i - kd) / OW;
const int ix = i % OW;
const int64_t ksize = OW * (KH > 1 ? KW : 1);
const int64_t kx = i / ksize;
const int64_t kd = kx * ksize;
const int64_t ky = (i - kd) / OW;
const int64_t ix = i % OW;
const int oh = blockIdx.y;
const int batch = blockIdx.z / IC;
const int ic = blockIdx.z % IC;
const int64_t oh = blockIdx.y;
const int64_t batch = blockIdx.z / IC;
const int64_t ic = blockIdx.z % IC;
const int64_t iiw = ix * s0 + kx * d0 - p0;
const int64_t iih = oh * s1 + ky * d1 - p1;
@@ -7298,19 +7341,33 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, const
concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
}
static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int ne03,
const int scale_factor, cudaStream_t stream) {
int ne0 = (ne00 * scale_factor);
int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02*ne03);
upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
}
static void pad_f32_cuda(const float * x, float * dst,
const int ne00, const int ne01, const int ne02,
const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
const int ne00, const int ne01, const int ne02, const int ne03,
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
dim3 gridDim(num_blocks, ne1, ne2*ne3);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
}
static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start, step);
}
static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,
const int dim, const int max_period, cudaStream_t stream) {
int half_ceil = (dim + 1) / 2;
int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne00, 1);
timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);
}
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
@@ -8443,8 +8500,8 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float *
template <typename T>
static void im2col_cuda(const float* x, T* dst,
int IW, int IH, int OW, int OH, int KW, int KH, int IC,
int batch, int batch_offset, int offset_delta,
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
int64_t batch, int64_t batch_offset, int64_t offset_delta,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
const int parallel_elements = OW * KW * KH;
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
@@ -9123,7 +9180,7 @@ static void ggml_cuda_op_group_norm(
int num_groups = dst->op_params[0];
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
group_norm_f32_cuda(src0_dd, dst_dd, num_groups * src0->ne[3], group_size, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
@@ -9156,7 +9213,7 @@ static void ggml_cuda_op_upscale(
const int scale_factor = dst->op_params[0];
upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], scale_factor, main_stream);
(void) src1;
(void) dst;
@@ -9172,8 +9229,49 @@ static void ggml_cuda_op_pad(
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
pad_f32_cuda(src0_dd, dst_dd,
src0->ne[0], src0->ne[1], src0->ne[2],
dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
static void ggml_cuda_op_arange(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(dst->type == GGML_TYPE_F32);
float start;
float stop;
float step;
memcpy(&start, (float *)dst->op_params + 0, sizeof(float));
memcpy(&stop, (float *)dst->op_params + 1, sizeof(float));
memcpy(&step, (float *)dst->op_params + 2, sizeof(float));
int64_t steps = (int64_t)ceil((stop - start) / step);
GGML_ASSERT(ggml_nelements(dst) == steps);
arange_f32_cuda(dst_dd, dst->ne[0], start, step, main_stream);
(void) src0;
(void) src1;
(void) src0_dd;
(void) src1_dd;
}
static void ggml_cuda_op_timestep_embedding(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int dim = dst->op_params[0];
const int max_period = dst->op_params[1];
timestep_embedding_f32_cuda(src0_dd, dst_dd, src0->ne[0], dst->nb[1], dim, max_period, main_stream);
(void) src1;
(void) dst;
@@ -10458,6 +10556,45 @@ static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, gg
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
}
static void ggml_cuda_arange(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU;
// dd = data device
float * src0_ddf = nullptr;
float * src1_ddf = nullptr;
float * dst_ddf = nullptr;
cuda_pool_alloc<float> dst_f;
ggml_cuda_set_device(g_main_device);
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
if (dst_on_device) {
dst_ddf = (float *) dst_extra->data_device[g_main_device];
} else {
dst_ddf = dst_f.alloc(ggml_nelements(dst));
}
// do the computation
ggml_cuda_op_arange(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
CUDA_CHECK(cudaGetLastError());
// copy dst to host if necessary
if (!dst_on_device) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
}
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
CUDA_CHECK(cudaDeviceSynchronize());
}
}
static void ggml_cuda_timestep_embedding(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_timestep_embedding);
}
static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
}
@@ -11358,6 +11495,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
case GGML_OP_PAD:
func = ggml_cuda_pad;
break;
case GGML_OP_ARANGE:
func = ggml_cuda_arange;
break;
case GGML_OP_TIMESTEP_EMBEDDING:
func = ggml_cuda_timestep_embedding;
break;
case GGML_OP_LEAKY_RELU:
func = ggml_cuda_leaky_relu;
break;
@@ -12098,7 +12241,7 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
UNUSED(backend);
}
GGML_CALL static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_cuda_set_main_device(cuda_ctx->device);
@@ -12134,7 +12277,7 @@ GGML_CALL static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, gg
GGML_ASSERT(ok);
}
return true;
return GGML_STATUS_SUCCESS;
}
GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
@@ -12253,6 +12396,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_GROUP_NORM:
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
return true;
default:

View File
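For reference, plain C++ equivalents of the two operators added to the CUDA backend above (and to the Metal backend below), ARANGE and TIMESTEP_EMBEDDING, mirroring the arithmetic of the kernels. This is a sketch for understanding the math, not ggml code; the parameter layout (start/stop/step as floats in op_params, dim/max_period as ints) follows the ops shown in the diff:

#include <cmath>
#include <cstdint>
#include <vector>

// ARANGE: out[i] = start + step * i, with ceil((stop - start) / step) elements,
// matching the assert in ggml_cuda_op_arange.
static std::vector<float> arange_ref(float start, float stop, float step) {
    const int64_t n = (int64_t) std::ceil((stop - start) / step);
    std::vector<float> out(n);
    for (int64_t i = 0; i < n; ++i) {
        out[i] = start + step * (float) i;
    }
    return out;
}

// TIMESTEP_EMBEDDING: sinusoidal embedding of a single timestep into `dim` values,
// cosines in the first half and sines in the second, as in timestep_embedding_f32.
static std::vector<float> timestep_embedding_ref(float timestep, int dim, int max_period) {
    std::vector<float> embed(dim, 0.0f);
    const int half = dim / 2;
    for (int j = 0; j < half; ++j) {
        const float freq = std::exp(-std::log((float) max_period) * j / half);
        const float arg  = timestep * freq;
        embed[j]        = std::cos(arg);
        embed[j + half] = std::sin(arg);
    }
    // for odd dim the trailing element simply stays zero-initialized here;
    // the kernels handle the odd case with explicit zero padding
    return embed;
}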

@@ -1927,10 +1927,10 @@ static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(g
return ggml_backend_kompute_buffer_type(ctx->device);
}
static bool ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
ggml_vk_graph_compute(ctx, cgraph);
return true;
return GGML_STATUS_SUCCESS;
}
static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {

View File

@@ -163,6 +163,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32,
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
@@ -569,6 +571,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
@@ -697,6 +701,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
return false;
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_LEAKY_RELU:
return true;
@@ -742,7 +748,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
}
}
static bool ggml_metal_graph_compute(
static enum ggml_status ggml_metal_graph_compute(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
@@ -1091,7 +1097,8 @@ static bool ggml_metal_graph_compute(
{
GGML_ASSERT(ggml_is_contiguous(src0));
const float scale = *(const float *) dst->op_params;
float scale;
memcpy(&scale, dst->op_params, sizeof(scale));
int64_t n = ggml_nelements(dst);
@@ -1250,11 +1257,15 @@ static bool ggml_metal_graph_compute(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
}
const float scale = ((float *) dst->op_params)[0];
const float max_bias = ((float *) dst->op_params)[1];
float scale;
float max_bias;
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];
const uint32_t n_head_kv = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
@@ -2086,6 +2097,7 @@ static bool ggml_metal_graph_compute(
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
@@ -2300,6 +2312,50 @@ static bool ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_ARANGE:
{
GGML_ASSERT(dst->type == GGML_TYPE_F32);
float start;
float step;
memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
[encoder setBytes:&start length:sizeof(start) atIndex:2];
[encoder setBytes:&step length:sizeof(step) atIndex:3];
const int nth = MIN(1024, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_TIMESTEP_EMBEDDING:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
const int dim = dst->op_params[0];
const int max_period = dst->op_params[1];
const int half = dim / 2;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
[encoder setBytes:&dim length:sizeof(dim) atIndex:3];
[encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
const int nth = MIN(1024, half);
[encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_ARGSORT:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -2428,7 +2484,7 @@ static bool ggml_metal_graph_compute(
MTLCommandBufferStatus status = [command_buffer status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
return false;
return GGML_STATUS_FAILED;
}
}
@@ -2437,7 +2493,7 @@ static bool ggml_metal_graph_compute(
}
}
return true;
return GGML_STATUS_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////
@@ -2739,7 +2795,7 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffe
UNUSED(backend);
}
GGML_CALL static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
GGML_CALL static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
return ggml_metal_graph_compute(metal_ctx, cgraph);
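One small but recurring change in the Metal path above: float parameters such as scale and max_bias are now read from dst->op_params with memcpy instead of a direct cast like ((float *) dst->op_params)[0]. Since op_params is an int32_t array, copying the bytes out avoids type-punned float loads (presumably the motivation, as such loads are undefined behaviour under strict aliasing). The pattern in isolation, with dst standing for any ggml_tensor:

// op_params is declared as an int32_t array; memcpy the bytes into a float
// instead of dereferencing the storage through a float pointer.
float scale;
float max_bias;
memcpy(&scale,    (const int32_t *) dst->op_params + 0, sizeof(scale));
memcpy(&max_bias, (const int32_t *) dst->op_params + 1, sizeof(max_bias));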

View File

@@ -1959,6 +1959,49 @@ kernel void kernel_pad_f32(
}
}
kernel void kernel_arange_f32(
device char * dst,
constant int64_t & ne0,
constant float & start,
constant float & step,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
device float * dst_ptr = (device float *) dst;
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
dst_ptr[i0] = start + step * i0;
}
}
kernel void kernel_timestep_embedding_f32(
device const char * src0,
device char * dst,
constant uint64_t & nb1,
constant int & dim,
constant int & max_period,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
int i = tgpig.x;
device float * embed_data = (device float *)(dst + i*nb1);
int half_ = dim / 2;
for (int j = tpitg.x; j < half_; j += ntg.x) {
float timestep = ((device float *)src0)[i];
float freq = (float)exp(-log((float)max_period) * j / half_);
float arg = timestep * freq;
embed_data[j ] = cos(arg);
embed_data[j + half_] = sin(arg);
}
if (dim % 2 != 0 && tpitg.x == 0) {
embed_data[dim] = 0.f;
}
}
// bitonic sort implementation following the CUDA kernels as reference
typedef void (argsort_t)(
device const float * x,
@@ -4087,71 +4130,71 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
};
constexpr constant static uint32_t iq3xs_grid[512] = {
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
constexpr constant static uint32_t iq3s_grid[512] = {
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
};
#define NGRID_IQ1S 512
@@ -4742,7 +4785,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
{
int nval = 8;
int pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i];
for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
@@ -4769,12 +4812,14 @@ void kernel_mul_mv_iq3_s_f32_impl(
for (int row = 0; row < N_DST; row++) {
const float db = dh[0];
const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf));
const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
float2 sum = {0};
for (int l = 0; l < 4; ++l) {
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
for (int j = 0; j < 4; ++j) {
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
@@ -4795,7 +4840,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
}
}
}
@@ -5685,15 +5730,15 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
device const uint8_t * qs = xb->qs + 8*ib32;
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
const uint8_t qh = xb->qh[ib32] >> 4*il;
const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f;
constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256)));
constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256)));
const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
}
grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256)));
grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256)));
grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
for (int i = 0; i < 4; ++i) {
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);


@@ -2231,7 +2231,7 @@ static ggml_backend_buffer_type_t ggml_backend_opencl_get_default_buffer_type(gg
GGML_UNUSED(backend);
}
static bool ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
for (int i = 0; i < graph->n_nodes; ++i) {
ggml_tensor * node = graph->nodes[i];
switch (node->op) {
@@ -2246,7 +2246,7 @@ static bool ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgrap
}
}
return true;
return GGML_STATUS_SUCCESS;
GGML_UNUSED(backend);
}


@@ -51,6 +51,7 @@
#define UNUSED GGML_UNUSED
// some compilers don't provide _mm256_set_m128i, e.g. gcc 7
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
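As a small usage sketch (assuming AVX is available and the macro above is in scope), MM256_SET_M128I(a, b) behaves like _mm256_set_m128i(a, b): it casts the low half up to 256 bits and inserts the high half into lane 1:

    #include <immintrin.h>
    __m128i lo = _mm_set1_epi32(1);
    __m128i hi = _mm_set1_epi32(2);
    __m256i v  = MM256_SET_M128I(hi, lo);   // low 128 bits = lo, high 128 bits = hi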
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
@@ -3818,71 +3819,71 @@ static const uint32_t iq3xxs_grid[256] = {
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
};
static const uint32_t iq3xs_grid[512] = {
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
static const uint32_t iq3s_grid[512] = {
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
};
#define NGRID_IQ2XXS 512
@@ -4162,11 +4163,11 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in
const uint8_t * signs = x[i].signs;
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const float db1 = d * (0.5f + (x[i].scales[ib32/2] & 0xf)) * 0.5f;
const float db2 = d * (0.5f + (x[i].scales[ib32/2] >> 4)) * 0.5f;
const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf));
const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4));
for (int l = 0; l < 4; ++l) {
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
@@ -4176,8 +4177,8 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in
qs += 8;
signs += 4;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)));
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
@@ -9563,7 +9564,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
const __m128i odd_bits = _mm_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
const __m128i full_sign_bits = _mm_or_si128(partial_sign_bits, odd_bits);
const __m256i full_signs = _mm256_set_m128i(full_sign_bits, full_sign_bits);
const __m256i full_signs = MM256_SET_M128I(full_sign_bits, full_sign_bits);
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)(y[i].qs+32));
@@ -9585,8 +9586,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
const __m256i sc1 = _mm256_set_m128i(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1));
const __m256i sc2 = _mm256_set_m128i(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1));
const __m256i sc1 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1));
const __m256i sc2 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1));
const __m256i sum = _mm256_add_epi32(_mm256_madd_epi16(sc1, dot1), _mm256_madd_epi16(sc2, dot2));
@@ -9653,8 +9654,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l);
const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h);
const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l);
const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h);
__m256i signs;
signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
@@ -10089,18 +10090,34 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
#if defined(__ARM_NEON)
typedef union {
uint16x8_t vec_index;
uint16_t index[8];
} vec_index_t;
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
};
static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1);
const uint8x16_t mask2 = vld1q_u8(k_mask2);
static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1);
const uint8x16_t mask2 = vld1q_u8(k_mask2);
const int16x8_t hshift = vld1q_s16(k_shift);
const uint16x8_t m256 = vdupq_n_u16(256);
const uint8x16_t m1 = vdupq_n_u8(1);
uint8x16x2_t vs;
ggml_int8x16x4_t q3s;
ggml_int8x16x4_t q8b;
vec_index_t idx;
#if QK_K == 256
uint32_t scales32[2];
const uint8_t * scales8 = (const uint8_t *)scales32;
#endif
float sumf = 0;
for (int i = 0; i < nb; ++i) {
@@ -10109,47 +10126,63 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
const uint8_t * restrict qh = x[i].qh;
const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
const int8_t * restrict q8 = y[i].qs;
#if QK_K == 256
memcpy(scales32, x[i].scales, 4);
scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
#endif
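A quick worked example of the packed scale precompute above (values chosen purely for illustration): if x[i].scales[0] == 0x4a, the low nibble 0xa becomes (0x0a << 1) | 0x01 = 0x15, so scales8[0] = 21 = 2*10 + 1, and the high nibble 0x4 becomes 0x09, so scales8[4] = 9 = 2*4 + 1 — the same 2*s + 1 factors the scalar fallback computes per block.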
int sumi1 = 0, sumi2 = 0;
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
const uint32x4_t aux32x4_0 = {iq3xs_grid[qs[ 0] | ((qh[ib32+0] << 8) & 256)], iq3xs_grid[qs[ 1] | ((qh[ib32+0] << 7) & 256)],
iq3xs_grid[qs[ 2] | ((qh[ib32+0] << 6) & 256)], iq3xs_grid[qs[ 3] | ((qh[ib32+0] << 5) & 256)]};
const uint32x4_t aux32x4_1 = {iq3xs_grid[qs[ 4] | ((qh[ib32+0] << 4) & 256)], iq3xs_grid[qs[ 5] | ((qh[ib32+0] << 3) & 256)],
iq3xs_grid[qs[ 6] | ((qh[ib32+0] << 2) & 256)], iq3xs_grid[qs[ 7] | ((qh[ib32+0] << 1) & 256)]};
const uint32x4_t aux32x4_2 = {iq3xs_grid[qs[ 8] | ((qh[ib32+1] << 8) & 256)], iq3xs_grid[qs[ 9] | ((qh[ib32+1] << 7) & 256)],
iq3xs_grid[qs[10] | ((qh[ib32+1] << 6) & 256)], iq3xs_grid[qs[11] | ((qh[ib32+1] << 5) & 256)]};
const uint32x4_t aux32x4_3 = {iq3xs_grid[qs[12] | ((qh[ib32+1] << 4) & 256)], iq3xs_grid[qs[13] | ((qh[ib32+1] << 3) & 256)],
iq3xs_grid[qs[14] | ((qh[ib32+1] << 2) & 256)], iq3xs_grid[qs[15] | ((qh[ib32+1] << 1) & 256)]};
qs += 16;
const uint8x16_t idx_l = vld1q_u8(qs); qs += 16;
idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256));
const uint32x4_t aux32x4_0 = {iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]};
const uint32x4_t aux32x4_1 = {iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]};
idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256));
const uint32x4_t aux32x4_2 = {iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]};
const uint32x4_t aux32x4_3 = {iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]};
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
vs.val[0] = vceqq_u8(vs.val[0], mask2);
vs.val[1] = vceqq_u8(vs.val[1], mask2);
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
q3s.val[0] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_0))), vreinterpretq_s8_u8(vs.val[0]));
q3s.val[1] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_1))), vreinterpretq_s8_u8(vs.val[1]));
q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
vs.val[0] = vceqq_u8(vs.val[0], mask2);
vs.val[1] = vceqq_u8(vs.val[1], mask2);
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
signs += 4;
q3s.val[2] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_2))), vreinterpretq_s8_u8(vs.val[0]));
q3s.val[3] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_3))), vreinterpretq_s8_u8(vs.val[1]));
q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2));
q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3));
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
#if QK_K == 256
sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0];
sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4];
#else
sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf));
sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4));
#endif
}
sumf += d*(sumi1 + sumi2);
}
*s = 0.25f * sumf;
*s = sumf;
#elif defined(__AVX2__)
@@ -10164,6 +10197,16 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
const __m256i idx_mask = _mm256_set1_epi32(256);
typedef union {
__m256i vec[2];
uint32_t index[16];
} index_t;
index_t idx;
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
@@ -10176,24 +10219,25 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q2_1 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+0] << 1) & 256)],
iq3xs_grid[qs[6] | ((qh[ib32+0] << 2) & 256)],
iq3xs_grid[qs[5] | ((qh[ib32+0] << 3) & 256)],
iq3xs_grid[qs[4] | ((qh[ib32+0] << 4) & 256)],
iq3xs_grid[qs[3] | ((qh[ib32+0] << 5) & 256)],
iq3xs_grid[qs[2] | ((qh[ib32+0] << 6) & 256)],
iq3xs_grid[qs[1] | ((qh[ib32+0] << 7) & 256)],
iq3xs_grid[qs[0] | ((qh[ib32+0] << 8) & 256)]);
qs += 8;
const __m256i q2_2 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+1] << 1) & 256)],
iq3xs_grid[qs[6] | ((qh[ib32+1] << 2) & 256)],
iq3xs_grid[qs[5] | ((qh[ib32+1] << 3) & 256)],
iq3xs_grid[qs[4] | ((qh[ib32+1] << 4) & 256)],
iq3xs_grid[qs[3] | ((qh[ib32+1] << 5) & 256)],
iq3xs_grid[qs[2] | ((qh[ib32+1] << 6) & 256)],
iq3xs_grid[qs[1] | ((qh[ib32+1] << 7) & 256)],
iq3xs_grid[qs[0] | ((qh[ib32+1] << 8) & 256)]);
qs += 8;
const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16;
idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]);
idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]);
idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);
idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);
idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l)));
idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1)));
// At least on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.
//const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
//const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
const __m256i q2_1 = _mm256_set_epi32(
iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
);
const __m256i q2_2 = _mm256_set_epi32(
iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
);
__m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
@@ -10221,7 +10265,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
}
*s = 0.25f * hsum_float_8(accumf);
*s = hsum_float_8(accumf);
#else
@@ -10238,8 +10282,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1;
int32_t sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
@@ -10251,8 +10295,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
bsum += sumi * ls1;
sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
@@ -10265,7 +10309,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
}
sumf += d * bsum;
}
*s = 0.25f * sumf;
*s = sumf;
#endif
}
@@ -10508,10 +10552,10 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs);
const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs);
const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs);
const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
@@ -10618,10 +10662,10 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
@@ -11912,7 +11956,8 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
}
float best = 0;
float scale = max/(2*kMaxQ-1);
for (int is = -15; is <= 15; ++is) {
for (int k = 0; k < bs4; ++k) is_on_grid[k] = false;
for (int is = -9; is <= 9; ++is) {
float id = (2*kMaxQ-1+is*0.2f)/max;
float this_scale = 1/id;
for (int k = 0; k < bs4; ++k) {
@@ -11948,7 +11993,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
if (n_not_ongrid > 0 && scale > 0) {
float id = 1/scale;
for (int k = 0; k < bs4; ++k) {
if (is_on_grid[k]) continue;
//if (is_on_grid[k]) continue;
uint16_t u = 0;
for (int i = 0; i < 4; ++i) {
int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
@@ -12004,7 +12049,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
}
float d = max_scale/31;
y[ibl].d = GGML_FP32_TO_FP16(d);
y[ibl].d = GGML_FP32_TO_FP16(d * 1.033f);
float id = 1/d;
for (int ib = 0; ib < QK_K/block_size; ib += 2) {
int l1 = nearest_int(0.5f*(id*scales[ib+0]-1));

(three file diffs suppressed because they are too large)


@@ -10,6 +10,7 @@ extern "C" {
#define GGML_VK_NAME "Vulkan"
#define GGML_VK_MAX_DEVICES 16
GGML_API void ggml_vk_instance_init(void);
GGML_API void ggml_vk_init_cpu_assist(void);
GGML_API void ggml_vk_preallocate_buffers_graph_cpu_assist(struct ggml_tensor * node);

ggml.c

@@ -320,6 +320,17 @@ static ggml_fp16_t ggml_table_exp_f16[1 << 16];
// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
float ggml_table_f32_f16[1 << 16];
const char * ggml_status_to_string(enum ggml_status status) {
switch (status) {
case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
case GGML_STATUS_FAILED: return "GGML status: error (operation failed)";
case GGML_STATUS_SUCCESS: return "GGML status: success";
case GGML_STATUS_ABORTED: return "GGML status: warning (operation aborted)";
}
return "GGML status: unknown";
}
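A minimal caller-side sketch of how the new status codes and this helper might be used together (gf and cplan are assumed to be an already-built graph and plan, with ggml.h and <stdio.h> included; names are illustrative):

    enum ggml_status status = ggml_graph_compute(&gf, &cplan);
    if (status != GGML_STATUS_SUCCESS) {
        fprintf(stderr, "%s\n", ggml_status_to_string(status));
    }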
// note: do not use these inside ggml.c
// these are meant to be used via the ggml.h API
float ggml_fp16_to_fp32(ggml_fp16_t x) {
@@ -1822,6 +1833,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"POOL_2D",
"UPSCALE",
"PAD",
"ARANGE",
"TIMESTEP_EMBEDDING",
"ARGSORT",
"LEAKY_RELU",
@@ -1850,7 +1863,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1908,6 +1921,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"pool_2d(x)",
"upscale(x)",
"pad(x)",
"arange(start, stop, step)",
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
"leaky_relu(x)",
@@ -1936,7 +1951,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -2895,11 +2910,21 @@ static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_
return ((const int32_t *)(tensor->op_params))[i];
}
static float ggml_get_op_params_f32(const struct ggml_tensor * tensor, uint32_t i) {
assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
return ((const float *)(tensor->op_params))[i];
}
static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) {
assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
((int32_t *)(tensor->op_params))[i] = value;
}
static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, float value) {
assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
((float *)(tensor->op_params))[i] = value;
}
struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
memset(tensor->data, 0, ggml_nbytes(tensor));
return tensor;
@@ -5898,6 +5923,55 @@ struct ggml_tensor * ggml_upscale(
return ggml_upscale_impl(ctx, a, scale_factor);
}
struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step) {
GGML_ASSERT(stop > start);
const int64_t steps = (int64_t) ceilf((stop - start) / step);
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
result->op = GGML_OP_ARANGE;
ggml_set_op_params_f32(result, 0, start);
ggml_set_op_params_f32(result, 1, stop);
ggml_set_op_params_f32(result, 2, step);
return result;
}
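For illustration, a call such as the following allocates a 1-D F32 tensor with ceil((stop - start) / step) elements and, once the graph is computed, fills it with start, start + step, ... (ctx is assumed to be a valid ggml_context):

    struct ggml_tensor * r = ggml_arange(ctx, 0.0f, 10.0f, 1.0f);   // 10 elements: 0, 1, ..., 9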
struct ggml_tensor * ggml_timestep_embedding(
struct ggml_context * ctx,
struct ggml_tensor * timesteps,
int dim,
int max_period) {
bool is_node = false;
if (timesteps->grad) {
GGML_ASSERT(false); // TODO: implement backward
is_node = true;
}
int actual_dim = dim;
if (dim % 2 != 0) {
actual_dim = dim + 1;
}
struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
result->op = GGML_OP_TIMESTEP_EMBEDDING;
ggml_set_op_params_i32(result, 0, dim);
ggml_set_op_params_i32(result, 1, max_period);
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = timesteps;
return result;
}
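A usage sketch (timesteps is assumed to be a 1-D F32 tensor of length N; dim and max_period are arbitrary example values). The result is the [N, dim] tensor described in the ggml.h comment below, with dim padded to the next even number internally when it is odd:

    struct ggml_tensor * emb = ggml_timestep_embedding(ctx, timesteps, 320, 10000);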
// ggml_argsort
struct ggml_tensor * ggml_argsort(
@@ -10231,7 +10305,7 @@ static void ggml_compute_forward_group_norm_f32(
int n_channels = src0->ne[2];
int n_groups = dst->op_params[0];
int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
for (int i = ith; i < n_groups; i+=nth) {
for (int i = ith; i < n_groups; i += nth) {
int start = i * n_channels_per_group;
int end = start + n_channels_per_group;
if (end > n_channels) {
@@ -10245,28 +10319,32 @@ static void ggml_compute_forward_group_norm_f32(
for (int64_t i01 = 0; i01 < ne01; i01++) {
const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
ggml_float sumr = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float)x[i00];
sumr += (ggml_float)x[i00];
}
sum += sumr;
}
}
float mean = sum / (ne00 * ne01 * step);
ggml_float sum2 = 0.0;
const float mean = sum / (ne00 * ne01 * step);
ggml_float sum2 = 0.0;
for (int64_t i02 = start; i02 < end; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
ggml_float sumr = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
float v = x[i00] - mean;
y[i00] = v;
sum2 += (ggml_float)(v * v);
sumr += (ggml_float)(v * v);
}
sum2 += sumr;
}
}
float variance = sum2 / (ne00 * ne01 * step);
const float variance = sum2 / (ne00 * ne01 * step);
const float scale = 1.0f / sqrtf(variance + eps);
for (int64_t i02 = start; i02 < end; i02++) {
@@ -13547,6 +13625,106 @@ static void ggml_compute_forward_pad(
}
}
// ggml_compute_forward_arange
static void ggml_compute_forward_arange_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return;
}
GGML_ASSERT(dst->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const float start = ggml_get_op_params_f32(dst, 0);
const float stop = ggml_get_op_params_f32(dst, 1);
const float step = ggml_get_op_params_f32(dst, 2);
const int64_t steps = (int64_t) ceilf((stop - start) / step);
GGML_ASSERT(ggml_nelements(dst) == steps);
for (int64_t i = ith; i < steps; i+= nth) {
float value = start + step * i;
((float *)dst->data)[i] = value;
}
}
static void ggml_compute_forward_arange(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
switch (dst->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_arange_f32(params, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}
static void ggml_compute_forward_timestep_embedding_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return;
}
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_UNARY_OP_LOCALS
const int dim = ggml_get_op_params_i32(dst, 0);
const int max_period = ggml_get_op_params_i32(dst, 1);
int half = dim / 2;
for (int64_t i = 0; i < ne00; i++) {
float * embed_data = (float *)((char *) dst->data + i*nb1);
for (int64_t j = ith; j < half; j += nth) {
float timestep = ((float *)src0->data)[i];
float freq = (float)expf(-logf(max_period) * j / half);
float arg = timestep * freq;
embed_data[j] = cosf(arg);
embed_data[j + half] = sinf(arg);
}
if (dim % 2 != 0 && ith == 0) {
embed_data[dim] = 0.f;
}
}
}
static void ggml_compute_forward_timestep_embedding(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_timestep_embedding_f32(params, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}
// ggml_compute_forward_argsort
static void ggml_compute_forward_argsort_f32(
@@ -15615,6 +15793,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_pad(params, tensor);
} break;
case GGML_OP_ARANGE:
{
ggml_compute_forward_arange(params, tensor);
} break;
case GGML_OP_TIMESTEP_EMBEDDING:
{
ggml_compute_forward_timestep_embedding(params, tensor);
} break;
case GGML_OP_ARGSORT:
{
ggml_compute_forward_argsort(params, tensor);
@@ -16617,6 +16803,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_ARANGE:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_TIMESTEP_EMBEDDING:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_ARGSORT:
{
GGML_ASSERT(false); // TODO: not implemented
@@ -17217,6 +17411,7 @@ struct ggml_compute_state {
ggml_thread_t thrd;
int ith;
struct ggml_compute_state_shared * shared;
enum ggml_status ec;
};
static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) {
@@ -17368,6 +17563,14 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
{
n_tasks = n_threads;
} break;
case GGML_OP_ARANGE:
{
n_tasks = n_threads;
} break;
case GGML_OP_TIMESTEP_EMBEDDING:
{
n_tasks = n_threads;
} break;
case GGML_OP_ARGSORT:
{
n_tasks = n_threads;
@@ -17502,7 +17705,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
while (true) {
if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
state->shared->node_n += 1;
return (thread_ret_t) GGML_EXIT_ABORTED;
state->ec = GGML_STATUS_ABORTED;
return 0;
}
if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
@@ -17624,7 +17828,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
}
}
return GGML_EXIT_SUCCESS;
return 0;
}
struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
@@ -17820,7 +18024,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
return cplan;
}
int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
{
GGML_ASSERT(cplan);
GGML_ASSERT(cplan->n_threads > 0);
@@ -17864,6 +18068,7 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
.thrd = 0,
.ith = j,
.shared = &state_shared,
.ec = GGML_STATUS_SUCCESS,
};
const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
@@ -17874,12 +18079,14 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
workers[0].ith = 0;
workers[0].shared = &state_shared;
workers[0].ec = GGML_STATUS_SUCCESS;
const int64_t perf_start_cycles = ggml_perf_cycles();
const int64_t perf_start_time_us = ggml_perf_time_us();
// this is a work thread too
int compute_status = (size_t) ggml_graph_compute_thread(&workers[0]);
ggml_graph_compute_thread(&workers[0]);
enum ggml_status compute_status = workers[0].ec;
// don't leave affinity set on the main thread
clear_numa_thread_affinity();
@@ -17889,6 +18096,8 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
for (int j = 1; j < n_threads; j++) {
const int rc = ggml_thread_join(workers[j].thrd, NULL);
GGML_ASSERT(rc == 0);
if (workers[j].ec != GGML_STATUS_SUCCESS)
compute_status = workers[j].ec;
}
}
@@ -17916,14 +18125,14 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
return compute_status;
}
void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size);
cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
ggml_graph_compute(cgraph, &cplan);
return ggml_graph_compute(cgraph, &cplan);
}
struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {

ggml.h

@@ -315,6 +315,16 @@
extern "C" {
#endif
enum ggml_status {
GGML_STATUS_ALLOC_FAILED = -2,
GGML_STATUS_FAILED = -1,
GGML_STATUS_SUCCESS = 0,
GGML_STATUS_ABORTED = 1,
};
// get ggml_status name string
GGML_API GGML_CALL const char * ggml_status_to_string(enum ggml_status status);
typedef uint16_t ggml_fp16_t;
// convert FP16 <-> FP32
@@ -454,6 +464,8 @@ extern "C" {
GGML_OP_POOL_2D,
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD,
GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,
@@ -1661,6 +1673,15 @@ extern "C" {
int p2,
int p3);
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
// timesteps: [N,]
// return: [N, dim]
GGML_API struct ggml_tensor * ggml_timestep_embedding(
struct ggml_context * ctx,
struct ggml_tensor * timesteps,
int dim,
int max_period);
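A minimal usage sketch of the new operator (illustrative only; it assumes an existing ggml context `ctx`, and filling the data plus building/computing the graph follows the usual ggml pattern):

// embed N = 4 timesteps into dim = 320 sinusoidal features, max_period = 10000
struct ggml_tensor * steps = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);       // timesteps: [N]
struct ggml_tensor * emb   = ggml_timestep_embedding(ctx, steps, 320, 10000); // result: [N, dim]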
// sort rows
enum ggml_sort_order {
GGML_SORT_ORDER_ASC,
@@ -1672,6 +1693,12 @@ extern "C" {
struct ggml_tensor * a,
enum ggml_sort_order order);
GGML_API struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step);
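For reference, a one-line sketch (illustrative; `ctx` is an existing ggml context). The result is a 1-D F32 tensor covering [start, stop) in increments of step:

// produces the 5 values 0, 2, 4, 6, 8
struct ggml_tensor * r = ggml_arange(ctx, 0.0f, 10.0f, 2.0f);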
// top k elements per row
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
@@ -1923,12 +1950,11 @@ extern "C" {
// ggml_graph_plan() has to be called before ggml_graph_compute()
// when plan.work_size > 0, caller must allocate memory for plan.work_data
GGML_API struct ggml_cplan ggml_graph_plan (const struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
GGML_API int ggml_graph_compute( struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
GGML_API struct ggml_cplan ggml_graph_plan (const struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
GGML_API enum ggml_status ggml_graph_compute ( struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
// same as ggml_graph_compute() but the work data is allocated as a part of the context
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
GGML_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
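With this change the compute entry points report errors through `enum ggml_status` instead of a plain int. A simplified sketch of the updated calling convention (work-buffer handling and includes are abbreviated; `gf` is an already built graph):

struct ggml_cplan plan = ggml_graph_plan(gf, /*n_threads =*/ 4);
uint8_t * work = NULL;
if (plan.work_size > 0) {
    work = malloc(plan.work_size); // caller owns the work buffer
    plan.work_data = work;
}
enum ggml_status st = ggml_graph_compute(gf, &plan);
if (st != GGML_STATUS_SUCCESS) {
    fprintf(stderr, "graph compute failed: %s\n", ggml_status_to_string(st));
}
free(work);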

File diff suppressed because it is too large

@@ -604,20 +604,28 @@ class PoolingType(IntEnum):
class GGMLQuantizationType(IntEnum):
F32 = 0
F16 = 1
Q4_0 = 2
Q4_1 = 3
Q5_0 = 6
Q5_1 = 7
Q8_0 = 8
Q8_1 = 9
Q2_K = 10
Q3_K = 11
Q4_K = 12
Q5_K = 13
Q6_K = 14
Q8_K = 15
F32 = 0
F16 = 1
Q4_0 = 2
Q4_1 = 3
Q5_0 = 6
Q5_1 = 7
Q8_0 = 8
Q8_1 = 9
Q2_K = 10
Q3_K = 11
Q4_K = 12
Q5_K = 13
Q6_K = 14
Q8_K = 15
IQ2_XXS = 16
IQ2_XS = 17
IQ3_XXS = 18
IQ1_S = 19
IQ4_NL = 20
IQ3_S = 21
IQ2_S = 22
IQ4_XS = 23
class GGUFEndian(IntEnum):
@@ -662,20 +670,28 @@ class GGUFValueType(IntEnum):
QK_K = 256
# Items here are (block size, type size)
GGML_QUANT_SIZES = {
GGMLQuantizationType.F32: (1, 4),
GGMLQuantizationType.F16: (1, 2),
GGMLQuantizationType.Q4_0: (32, 2 + 16),
GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
GGMLQuantizationType.Q8_0: (32, 2 + 32),
GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4),
GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12),
GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12),
GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12),
GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16),
GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8),
GGMLQuantizationType.F32: (1, 4),
GGMLQuantizationType.F16: (1, 2),
GGMLQuantizationType.Q4_0: (32, 2 + 16),
GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
GGMLQuantizationType.Q8_0: (32, 2 + 32),
GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4),
GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12),
GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12),
GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12),
GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16),
GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8),
GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4),
GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32),
GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8),
GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16),
GGMLQuantizationType.IQ4_NL: (32, 2 + 16),
GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4),
GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16),
GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64),
}
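As a worked example of the (block size, type size) pairs above: a 4096 x 4096 tensor stored as Q4_0 takes 4096 * 4096 / 32 blocks * 18 bytes = 9,437,184 bytes, i.e. 4.5 bits per weight; IQ4_NL has the same footprint, since its entry is also (32, 2 + 16).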


@@ -362,7 +362,7 @@ class GGUFWriter:
self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
def add_pooling_type(self, value: PoolingType) -> None:
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value)
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
def add_rope_dimension_count(self, count: int) -> None:
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)


@@ -15,7 +15,7 @@ array ::=
string ::=
"\"" (
[^"\\] |
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws


@@ -24,7 +24,7 @@ array ::=
string ::=
"\"" (
[^"\\] |
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws

llama.cpp

File diff suppressed because it is too large

llama.h

@@ -129,6 +129,7 @@ extern "C" {
};
enum llama_pooling_type {
LLAMA_POOLING_TYPE_UNSPECIFIED = -1,
LLAMA_POOLING_TYPE_NONE = 0,
LLAMA_POOLING_TYPE_MEAN = 1,
LLAMA_POOLING_TYPE_CLS = 2,
@@ -162,7 +163,7 @@ extern "C" {
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
// - pos : the positions of the respective token in the sequence
// - seq_id : the sequence to which the respective token belongs
// - logits : if zero, the logits for the respective token will not be output
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
//
typedef struct llama_batch {
int32_t n_tokens;
@@ -172,7 +173,7 @@ extern "C" {
llama_pos * pos;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits;
int8_t * logits; // TODO: rename this to "output"
// NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below
@@ -236,7 +237,10 @@ extern "C" {
uint32_t n_batch; // prompt processing maximum batch size
uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing
int32_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
// (ignored if no pooling layer)
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency, 0 = from model
@@ -255,10 +259,15 @@ extern "C" {
enum ggml_type type_v; // data type for V cache
// Keep the booleans together to avoid misalignment during copy-by-value.
bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embedding; // embedding mode only
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
// Abort callback
// if it returns true, execution of llama_decode() will be aborted
// currently works only with CPU execution
ggml_abort_callback abort_callback;
void * abort_callback_data;
};
// model quantization parameters
@@ -632,7 +641,14 @@ extern "C" {
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
// Token logits obtained from the last call to llama_eval()
// Set whether the context is in embedding mode or not
// If set to true, llama_decode() will extract embeddings for the batch
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
// Set abort callback
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
// Token logits obtained from the last call to llama_decode()
// The logits for the last token are stored in the last row
// Logits for which llama_batch.logits[i] == 0 are undefined
// Rows: n_tokens provided with llama_batch
@@ -643,14 +659,20 @@ extern "C" {
// llama_get_logits(ctx) + i*n_vocab
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
// Get the embeddings for the input
// shape: [n_embd] (1-dimensional)
// Get all output token embeddings
// shape: [n_tokens*n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Get the embeddings for the ith sequence
// Get the embeddings for the ith token
// llama_get_embeddings(ctx) + i*n_embd
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
// Get the embeddings for a sequence id
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
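A minimal sketch of how the reworked getters fit together (illustrative; it assumes the context was created with the `embeddings` flag enabled and a pooling type other than LLAMA_POOLING_TYPE_NONE, and that `batch` has already been filled):

if (llama_decode(ctx, batch) == 0) {
    const int n_embd  = llama_n_embd(model);
    const float * emb = llama_get_embeddings_seq(ctx, /*seq_id =*/ 0); // pooled vector for sequence 0
    if (emb != NULL) {
        // emb[0 .. n_embd - 1] holds the pooled embedding for that sequence
    }
}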
//
// Vocab
//

scripts/pod-llama.sh (new file)

@@ -0,0 +1,213 @@
#!/bin/bash
#
# Use this script only on fresh pods (runpod.io)!
# Otherwise, it can break your environment!
#
if [ -z "$1" ]; then
echo "Usage: $0 <data>"
echo " 0: no models"
echo " 1: tinyllama-1b"
echo " 2: codellama-7b"
echo " 3: codellama-13b"
echo " 4: codellama-34b"
echo " 5: codellama-7b-instruct"
echo " 6: codellama-13b-instruct"
echo " 7: codellama-34b-instruct"
exit 1
fi
set -x
# setup deps
apt-get update
apt-get install -y git-lfs cmake cmake-curses-gui vim ruby
git-lfs install
if [ ! -d "/workspace" ]; then
ln -sfn $(pwd) /workspace
fi
# download data
cd /workspace
# this is useful to git clone repos without doubling the disk size due to .git
git clone https://github.com/iboB/git-lfs-download
ln -sfn /workspace/git-lfs-download/git-lfs-download /usr/local/bin/git-lfs-download
# llama.cpp
cd /workspace
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
LLAMA_CUBLAS=1 make -j
ln -sfn /workspace/TinyLlama-1.1B-Chat-v0.3 ./models/tinyllama-1b
ln -sfn /workspace/CodeLlama-7b-hf ./models/codellama-7b
ln -sfn /workspace/CodeLlama-13b-hf ./models/codellama-13b
ln -sfn /workspace/CodeLlama-34b-hf ./models/codellama-34b
ln -sfn /workspace/CodeLlama-7b-Instruct-hf ./models/codellama-7b-instruct
ln -sfn /workspace/CodeLlama-13b-Instruct-hf ./models/codellama-13b-instruct
ln -sfn /workspace/CodeLlama-34b-Instruct-hf ./models/codellama-34b-instruct
pip install -r requirements.txt
# cmake
cd /workspace/llama.cpp
mkdir build-cublas
cd build-cublas
cmake -DLLAMA_CUBLAS=1 ../
make -j
if [ "$1" -eq "0" ]; then
exit 0
fi
# more models
if [ "$1" -eq "1" ]; then
cd /workspace
git-lfs-download https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.3
cd /workspace/llama.cpp
python3 convert.py ./models/tinyllama-1b --outfile ./models/tinyllama-1b/ggml-model-f16.gguf --outtype f16
./quantize ./models/tinyllama-1b/ggml-model-f16.gguf ./models/tinyllama-1b/ggml-model-q4_0.gguf q4_0
./quantize ./models/tinyllama-1b/ggml-model-f16.gguf ./models/tinyllama-1b/ggml-model-q4_k.gguf q4_k
./quantize ./models/tinyllama-1b/ggml-model-f16.gguf ./models/tinyllama-1b/ggml-model-q8_0.gguf q8_0
fi
if [ "$1" -eq "2" ]; then
cd /workspace
git-lfs-download https://huggingface.co/codellama/CodeLlama-7b-hf --without *safetensors*
rm -v ./CodeLlama-7b-hf/*safetensors*
cd /workspace/llama.cpp
python3 convert.py ./models/codellama-7b --outfile ./models/codellama-7b/ggml-model-f16.gguf --outtype f16
./quantize ./models/codellama-7b/ggml-model-f16.gguf ./models/codellama-7b/ggml-model-q4_0.gguf q4_0
./quantize ./models/codellama-7b/ggml-model-f16.gguf ./models/codellama-7b/ggml-model-q4_k.gguf q4_k
./quantize ./models/codellama-7b/ggml-model-f16.gguf ./models/codellama-7b/ggml-model-q8_0.gguf q8_0
fi
if [ "$1" -eq "3" ]; then
cd /workspace
git-lfs-download https://huggingface.co/codellama/CodeLlama-13b-hf --without *safetensors*
rm -v ./CodeLlama-13b-hf/*safetensors*
cd /workspace/llama.cpp
python3 convert.py ./models/codellama-13b --outfile ./models/codellama-13b/ggml-model-f16.gguf --outtype f16
./quantize ./models/codellama-13b/ggml-model-f16.gguf ./models/codellama-13b/ggml-model-q4_0.gguf q4_0
./quantize ./models/codellama-13b/ggml-model-f16.gguf ./models/codellama-13b/ggml-model-q4_k.gguf q4_k
./quantize ./models/codellama-13b/ggml-model-f16.gguf ./models/codellama-13b/ggml-model-q8_0.gguf q8_0
fi
if [ "$1" -eq "4" ]; then
cd /workspace
git-lfs-download https://huggingface.co/codellama/CodeLlama-34b-hf --without *safetensors*
rm -v ./CodeLlama-34b-hf/*safetensors*
cd /workspace/llama.cpp
python3 convert.py ./models/codellama-34b --outfile ./models/codellama-34b/ggml-model-f16.gguf --outtype f16
./quantize ./models/codellama-34b/ggml-model-f16.gguf ./models/codellama-34b/ggml-model-q4_0.gguf q4_0
./quantize ./models/codellama-34b/ggml-model-f16.gguf ./models/codellama-34b/ggml-model-q4_k.gguf q4_k
./quantize ./models/codellama-34b/ggml-model-f16.gguf ./models/codellama-34b/ggml-model-q8_0.gguf q8_0
fi
if [ "$1" -eq "5" ]; then
cd /workspace
git-lfs-download https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf --without *safetensors*
rm -v ./CodeLlama-7b-Instruct-hf/*safetensors*
cd /workspace/llama.cpp
python3 convert.py ./models/codellama-7b-instruct --outfile ./models/codellama-7b-instruct/ggml-model-f16.gguf --outtype f16
./quantize ./models/codellama-7b-instruct/ggml-model-f16.gguf ./models/codellama-7b-instruct/ggml-model-q4_0.gguf q4_0
./quantize ./models/codellama-7b-instruct/ggml-model-f16.gguf ./models/codellama-7b-instruct/ggml-model-q4_k.gguf q4_k
./quantize ./models/codellama-7b-instruct/ggml-model-f16.gguf ./models/codellama-7b-instruct/ggml-model-q8_0.gguf q8_0
fi
if [ "$1" -eq "6" ]; then
cd /workspace
git-lfs-download https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf --without *safetensors*
rm -v ./CodeLlama-13b-Instruct-hf/*safetensors*
cd /workspace/llama.cpp
python3 convert.py ./models/codellama-13b-instruct --outfile ./models/codellama-13b-instruct/ggml-model-f16.gguf --outtype f16
./quantize ./models/codellama-13b-instruct/ggml-model-f16.gguf ./models/codellama-13b-instruct/ggml-model-q4_0.gguf q4_0
./quantize ./models/codellama-13b-instruct/ggml-model-f16.gguf ./models/codellama-13b-instruct/ggml-model-q4_k.gguf q4_k
./quantize ./models/codellama-13b-instruct/ggml-model-f16.gguf ./models/codellama-13b-instruct/ggml-model-q8_0.gguf q8_0
fi
if [ "$1" -eq "7" ]; then
cd /workspace
git-lfs-download https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf --without *safetensors*
rm -v ./CodeLlama-34b-Instruct-hf/*safetensors*
cd /workspace/llama.cpp
python3 convert.py ./models/codellama-34b-instruct --outfile ./models/codellama-34b-instruct/ggml-model-f16.gguf --outtype f16
./quantize ./models/codellama-34b-instruct/ggml-model-f16.gguf ./models/codellama-34b-instruct/ggml-model-q4_0.gguf q4_0
./quantize ./models/codellama-34b-instruct/ggml-model-f16.gguf ./models/codellama-34b-instruct/ggml-model-q4_k.gguf q4_k
./quantize ./models/codellama-34b-instruct/ggml-model-f16.gguf ./models/codellama-34b-instruct/ggml-model-q8_0.gguf q8_0
fi
if [ "$1" -eq "1" ]; then
# perf + perplexity
cd /workspace/llama.cpp/build-cublas
make -j && ../scripts/run-all-perf.sh tinyllama-1b "f16" "-ngl 99 -t 1 -p 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,64,128,256,512,1024,2048 -n 128"
../scripts/get-wikitext-2.sh
unzip wikitext-2-raw-v1.zip
make -j && ./bin/perplexity -m ../models/tinyllama-1b/ggml-model-f16.gguf -f ./wikitext-2-raw/wiki.test.raw -ngl 100 --chunks 32
# batched
cd /workspace/llama.cpp
LLAMA_CUBLAS=1 make -j && ./batched ./models/tinyllama-1b/ggml-model-f16.gguf "Hello, my name is" 8 128 999
# batched-bench
cd /workspace/llama.cpp
LLAMA_CUBLAS=1 make -j && ./batched-bench ./models/tinyllama-1b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32
# parallel
cd /workspace/llama.cpp
LLAMA_CUBLAS=1 make -j && ./parallel -m ./models/tinyllama-1b/ggml-model-f16.gguf -t 1 -ngl 100 -c 4096 -b 512 -s 1 -np 8 -ns 128 -n 100 -cb
fi
# speculative
#if [ "$1" -eq "7" ]; then
# cd /workspace/llama.cpp
#
# LLAMA_CUBLAS=1 make -j && ./speculative -m ./models/codellama-34b-instruct/ggml-model-f16.gguf -md ./models/codellama-7b-instruct/ggml-model-q4_0.gguf -p "# Dijkstra's shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n" -e -ngl 999 -ngld 999 -t 4 -n 512 -c 4096 -s 21 --draft 16 -np 1 --temp 0.0
#fi
# more benches
#LLAMA_CUBLAS=1 make -j && ./batched-bench ./models/codellama-7b/ggml-model-q4_k.gguf 4096 1 99 1 512,3200 128,128,800 1
#LLAMA_CUBLAS=1 make -j && ./batched-bench ./models/codellama-13b/ggml-model-q4_k.gguf 4096 1 99 1 512,3200 128,128,800 1


@@ -1 +1 @@
b458250b736a7473f7ff3560d47c93f1644f3290
8695910a39102609073d0e099aa7c97d6bcb3bf9


@@ -1412,6 +1412,50 @@ struct test_pad : public test_case {
}
};
// GGML_OP_ARANGE
struct test_arange : public test_case {
const ggml_type type;
const float start;
const float stop;
const float step;
std::string vars() override {
return VARS_TO_STR4(type, start, stop, step);
}
test_arange(ggml_type type = GGML_TYPE_F32,
float start = 0.f, float stop = 10.f, float step = 1.f)
: type(type), start(start), stop(stop), step(step) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * out = ggml_arange(ctx, start, stop, step);
return out;
}
};
// GGML_OP_TIMESTEP_EMBEDDING
struct test_timestep_embedding : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne_a;
const int dim;
const int max_period;
std::string vars() override {
return VARS_TO_STR4(type, ne_a, dim, max_period);
}
test_timestep_embedding(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {2, 1, 1, 1},
int dim = 320, int max_period=10000)
: type(type), ne_a(ne_a), dim(dim), max_period(max_period) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_tensor * out = ggml_timestep_embedding(ctx, a, dim, max_period);
return out;
}
};
// GGML_OP_LEAKY_RELU
struct test_leaky_relu : public test_case {
const ggml_type type;
@@ -2126,6 +2170,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_group_norm());
test_cases.emplace_back(new test_acc());
test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_arange());
test_cases.emplace_back(new test_timestep_embedding());
test_cases.emplace_back(new test_leaky_relu());
// these tests are disabled to save execution time, but they can be handy for debugging